diff --git a/README.md b/README.md
index 31365da89169cfe2be58de89a574ae4b69e40224..3467a31f23b7f770d32afb91cb62d5207ccf3d62 100644
--- a/README.md
+++ b/README.md
@@ -20,4 +20,60 @@ and [Network In Network (Lin et al., 2014)](https://arxiv.org/abs/1312.4400).
 add it to `src/join_settings.py` in the hourly data section. Replace the `TOAR_SERVICE_URL` and the `Authorization`
 value. To make sure that this **sensitive** data is not uploaded to the remote server, use the following command to
 prevent git from tracking this file: `git update-index --assume-unchanged src/join_settings.py
-`
\ No newline at end of file
+`
+
+# Customise your experiment
+
+This section summarises which parameters can be customised for a training run.
+
+## Transformation
+
+There are two different approaches (called scopes) to transform the data:
+1) `station`: transform the data of each station independently (somewhat like batch normalisation)
+1) `data`: transform the data of all stations with shared metrics
+
+The transformation is set via the `transformation` attribute. If `transformation = None` is given to `ExperimentSetup`,
+the data is not transformed at all. For all other setups, use the following dictionary structure to specify the
+transformation.
+```
+transformation = {"scope": <...>,
+                  "method": <...>,
+                  "mean": <...>,
+                  "std": <...>}
+ExperimentSetup(..., transformation=transformation, ...)
+```
+
+### scopes
+
+**station**: mean and std are not used
+
+**data**: either provide already calculated values for mean and std (if required by the transformation method), or
+choose from the different calculation schemes explained in the mean and std section.
+
+### supported transformation methods
+Currently supported methods are:
+* standardise (default, if method is not given)
+* centre
+
+### mean and std
+`"mean"="accurate"`: calculate the accurate values of mean and std (depending on the method) by using all data.
+Although this method is accurate, the calculation may take some time. Furthermore, it could potentially lead to memory
+issues (not explored yet, but this could appear for very large amounts of data).
+
+`"mean"="estimate"`: estimate mean and std (depending on the method). For each station, mean and std are calculated
+and afterwards aggregated by taking the mean over all station-wise metrics. This method is less accurate, especially
+regarding the std calculation, but in return much faster.
+
+We recommend using the latter method *estimate* for the following reasons:
+* much faster calculation
+* the real accuracy of mean and std is less important, because it is "just" a transformation / scaling
+* the accuracy of the mean is almost as high as in the *accurate* case, because
+$\overline{x_{ij}} = \overline{\left(\overline{x_i}\right)_j}$ holds. The only difference is that in the *estimate*
+case each station's mean is weighted equally, independently of the actual data count of the station.
+* the accuracy of the std is lower for *estimate*, because
+$\mathrm{Var}\left(x_{ij}\right) \ne \overline{\mathrm{Var}\left(x_i\right)_j}$, but the mean of all station-wise stds
+is still a decent estimate of the true std.
+
+`"mean"=<value, e.g. xr.DataArray>`: If mean and std are already calculated or shall be set manually, just add the
+scaling values instead of a calculation method. For the method *centre*, std can still be None, but it is required for
+the *standardise* method. **Important**: The format of the given values **must** match the internal data format of the
+DataPrep class: `xr.DataArray` with `dims=["variables"]` and one value for each variable. An example is shown below.
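+
+For example, a fully manual setup with precalculated statistics could look like this (a sketch; the numbers are
+illustrative only, not real climatological values):
+```
+import xarray as xr
+
+mean = xr.DataArray([45, 12], coords={"variables": ["o3", "temp"]}, dims=["variables"])
+std = xr.DataArray([20, 8], coords={"variables": ["o3", "temp"]}, dims=["variables"])
+transformation = {"scope": "data", "method": "standardise", "mean": mean, "std": std}
+ExperimentSetup(..., transformation=transformation, ...)
+```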
\ No newline at end of file
diff --git a/src/data_handling/data_distributor.py b/src/data_handling/data_distributor.py
index c6f38a6f0e70518956bcbbd51a6fdfc1a1e7849f..b1624410e746ab779b20a60d6a7d19b4ae3b1267 100644
--- a/src/data_handling/data_distributor.py
+++ b/src/data_handling/data_distributor.py
@@ -12,11 +12,11 @@ import numpy as np

 class Distributor(keras.utils.Sequence):

     def __init__(self, generator: keras.utils.Sequence, model: keras.models, batch_size: int = 256,
-                 fit_call: bool = True):
+                 permute_data: bool = False):
         self.generator = generator
         self.model = model
         self.batch_size = batch_size
-        self.fit_call = fit_call
+        self.do_data_permutation = permute_data

     def _get_model_rank(self):
         mod_out = self.model.output_shape
@@ -33,6 +33,16 @@ class Distributor(keras.utils.Sequence):
     def _get_number_of_mini_batches(self, values):
         return math.ceil(values[0].shape[0] / self.batch_size)

+    def _permute_data(self, x, y):
+        """
+        Permute inputs x and labels y with a single shared permutation, so that corresponding samples stay aligned.
+        The permutation is only applied if `do_data_permutation` is enabled.
+        """
+        if self.do_data_permutation:
+            p = np.random.permutation(len(x))  # equiv to .shape[0]
+            x = x[p]
+            y = y[p]
+        return x, y
+
     def distribute_on_batches(self, fit_call=True):
         while True:
             for k, v in enumerate(self.generator):
@@ -42,6 +52,8 @@
                 num_mini_batches = self._get_number_of_mini_batches(v)
                 x_total = np.copy(v[0])
                 y_total = np.copy(v[1])
+                # permute order for mini-batches
+                x_total, y_total = self._permute_data(x_total, y_total)
                 for prev, curr in enumerate(range(1, num_mini_batches+1)):
                     x = x_total[prev*self.batch_size:curr*self.batch_size, ...]
                     y = [y_total[prev*self.batch_size:curr*self.batch_size, ...] for _ in range(mod_rank)]
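The permutation above relies on one shared index vector for inputs and labels. A standalone sketch with toy shapes
(independent of the repository code) shows why this keeps samples aligned:

```python
import numpy as np

x = np.arange(20).reshape(10, 2)   # 10 samples, 2 features
y = np.arange(10).reshape(10, 1)   # 10 matching labels
p = np.random.permutation(len(x))  # one shared permutation index
x, y = x[p], y[p]                  # row i of x still belongs to row i of y
```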
diff --git a/src/data_handling/data_generator.py b/src/data_handling/data_generator.py
index 19a94fbb9dbbc8f382a225c852f34971a98395b8..cf4a0e3d8bc2483e787dfea31c5c9a32fb437fe1 100644
--- a/src/data_handling/data_generator.py
+++ b/src/data_handling/data_generator.py
@@ -2,8 +2,9 @@ __author__ = 'Felix Kleinert, Lukas Leufen'
 __date__ = '2019-11-07'

 import os
-from typing import Union, List, Tuple, Any
+from typing import Union, List, Tuple, Any, Dict

+import dask.array as da
 import keras
 import xarray as xr
 import pickle
@@ -11,6 +12,7 @@ import logging

 from src import helpers
 from src.data_handling.data_preparation import DataPrep
+from src.join import EmptyQueryResult


 class DataGenerator(keras.utils.Sequence):
@@ -25,7 +27,7 @@ class DataGenerator(keras.utils.Sequence):
     def __init__(self, data_path: str, network: str, stations: Union[str, List[str]], variables: List[str],
                  interpolate_dim: str, target_dim: str, target_var: str, station_type: str = None,
                  interpolate_method: str = "linear", limit_nan_fill: int = 1, window_history_size: int = 7,
-                 window_lead_time: int = 4, transform_method: str = "standardise", **kwargs):
+                 window_lead_time: int = 4, transformation: Dict = None, **kwargs):
         self.data_path = os.path.abspath(data_path)
         self.data_path_tmp = os.path.join(os.path.abspath(data_path), "tmp")
         if not os.path.exists(self.data_path_tmp):
@@ -41,8 +43,8 @@ class DataGenerator(keras.utils.Sequence):
         self.limit_nan_fill = limit_nan_fill
         self.window_history_size = window_history_size
         self.window_lead_time = window_lead_time
-        self.transform_method = transform_method
         self.kwargs = kwargs
+        self.transformation = self.setup_transformation(transformation)

     def __repr__(self):
         """
@@ -94,30 +96,100 @@ class DataGenerator(keras.utils.Sequence):
         data = self.get_data_generator(key=item)
         return data.get_transposed_history(), data.label.squeeze("Stations").transpose("datetime", "window")

-    def get_data_generator(self, key: Union[str, int] = None, local_tmp_storage: bool = True) -> DataPrep:
+    def setup_transformation(self, transformation):
+        """
+        Resolve the given transformation dictionary: fill in the default scope and method and, for scope "data",
+        replace a calculation scheme given as "mean" ("accurate" or "estimate") by the calculated statistics.
+        Returns None if no transformation is given.
+        """
+        if transformation is None:
+            return
+        transformation = transformation.copy()
+        scope = transformation.get("scope", "station")
+        method = transformation.get("method", "standardise")
+        mean = transformation.get("mean", None)
+        std = transformation.get("std", None)
+        if scope == "data":
+            if isinstance(mean, str):
+                if mean == "accurate":
+                    mean, std = self.calculate_accurate_transformation(method)
+                elif mean == "estimate":
+                    mean, std = self.calculate_estimated_transformation(method)
+                else:
+                    raise ValueError(f"given mean attribute must either be equal to strings 'accurate' or 'estimate' "
+                                     f"or be an array with already calculated means. Given was: {mean}")
+        elif scope == "station":
+            mean, std = None, None
+        else:
+            raise ValueError(f"Scope argument can either be 'station' or 'data'. Given was: {scope}")
+        transformation["method"] = method
+        transformation["mean"] = mean
+        transformation["std"] = std
+        return transformation
+
+    def calculate_accurate_transformation(self, method):
+        """
+        Calculate mean and std (as required by the given method) over the entire data of all stations, using dask to
+        limit the memory footprint.
+        """
+        tmp = []
+        mean = None
+        std = None
+        for station in self.stations:
+            try:
+                data = DataPrep(self.data_path, self.network, station, self.variables, station_type=self.station_type,
+                                **self.kwargs)
+                chunks = (1, 100, data.data.shape[2])
+                tmp.append(da.from_array(data.data.data, chunks=chunks))
+            except EmptyQueryResult:
+                continue
+        tmp = da.concatenate(tmp, axis=1)
+        if method in ["standardise", "centre"]:
+            mean = da.nanmean(tmp, axis=1).compute()
+            mean = xr.DataArray(mean.flatten(), coords={"variables": sorted(self.variables)}, dims=["variables"])
+            if method == "standardise":
+                std = da.nanstd(tmp, axis=1).compute()
+                std = xr.DataArray(std.flatten(), coords={"variables": sorted(self.variables)}, dims=["variables"])
+        else:
+            raise NotImplementedError
+        return mean, std
+
+    def calculate_estimated_transformation(self, method):
+        """
+        Estimate mean and std (as required by the given method) by averaging the station-wise statistics over all
+        stations; see the README for the accuracy trade-off compared to the accurate calculation.
+        """
+        data = [[]]*len(self.variables)
+        coords = {"variables": self.variables, "Stations": range(0)}
+        mean = xr.DataArray(data, coords=coords, dims=["variables", "Stations"])
+        std = xr.DataArray(data, coords=coords, dims=["variables", "Stations"])
+        for station in self.stations:
+            try:
+                data = DataPrep(self.data_path, self.network, station, self.variables, station_type=self.station_type,
+                                **self.kwargs)
+                data.transform("datetime", method=method)
+                mean = mean.combine_first(data.mean)
+                std = std.combine_first(data.std)
+                data.transform("datetime", method=method, inverse=True)
+            except EmptyQueryResult:
+                continue
+        return mean.mean("Stations") if mean.shape[1] > 0 else None, std.mean("Stations") if std.shape[1] > 0 else None
+
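The trade-off between the two schemes is easy to check numerically; a standalone sketch with two made-up stations of
unequal length (values invented for illustration):

```python
import numpy as np

s1 = np.array([1., 2., 3., 4.])  # station 1, four samples
s2 = np.array([10., 20.])        # station 2, two samples

pooled = np.concatenate([s1, s2])
accurate_mean, accurate_std = pooled.mean(), pooled.std()  # 6.67, 6.62

estimated_mean = np.mean([s1.mean(), s2.mean()])  # 8.75, every station weighted equally
estimated_std = np.mean([s1.std(), s2.std()])     # 3.06, mean of station-wise stds

# The means differ only through the equal station weighting; the std estimate is
# rougher, because the pooled variance also contains the spread between the
# station means.
```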
+    def get_data_generator(self, key: Union[str, int] = None, load_local_tmp_storage: bool = True,
+                           save_local_tmp_storage: bool = True) -> DataPrep:
         """
         Select data for given key, create a DataPrep object and interpolate, transform, make history and labels and
         remove nans.
         :param key: station key to choose the data generator.
-        :param local_tmp_storage: say if data should be processed from scratch or loaded as already processed data from
-            tmp pickle file to save computational time (but of course more disk space required).
+        :param load_local_tmp_storage: whether to load already processed data from a tmp pickle file instead of
+            processing it from scratch; this saves computation time, but of course requires more disk space.
+        :param save_local_tmp_storage: save the processed data as a temporary pickle file locally (default True)
         :return: preprocessed data as a DataPrep instance
         """
         station = self.get_station_key(key)
         try:
-            if not local_tmp_storage:
+            if not load_local_tmp_storage:
                 raise FileNotFoundError
             data = self._load_pickle_data(station, self.variables)
         except FileNotFoundError:
             logging.info(f"load not pickle data for {station}")
             data = DataPrep(self.data_path, self.network, station, self.variables, station_type=self.station_type,
                             **self.kwargs)
+            if self.transformation is not None:
+                data.transform("datetime", **helpers.dict_pop(self.transformation, "scope"))
             data.interpolate(self.interpolate_dim, method=self.interpolate_method, limit=self.limit_nan_fill)
-            data.transform("datetime", method=self.transform_method)
             data.make_history_window(self.interpolate_dim, self.window_history_size)
             data.make_labels(self.target_dim, self.target_var, self.interpolate_dim, self.window_lead_time)
             data.history_label_nan_remove(self.interpolate_dim)
-            self._save_pickle_data(data)
+            if save_local_tmp_storage:
+                self._save_pickle_data(data)
         return data

     def _save_pickle_data(self, data: Any):
diff --git a/src/data_handling/data_preparation.py b/src/data_handling/data_preparation.py
index 5bca71f52c9f136b5910d4e080491e0ff86484ae..98b47a6df3825581564fa9aaef7be8698408760e 100644
--- a/src/data_handling/data_preparation.py
+++ b/src/data_handling/data_preparation.py
@@ -216,7 +216,7 @@ class DataPrep(object):
             self.data, self.mean, self.std = f_inverse(self.data, self.mean, self.std, self._transform_method)
             self._transform_method = None

-    def transform(self, dim: Union[str, int] = 0, method: str = 'standardise', inverse: bool = False) -> None:
+    def transform(self, dim: Union[str, int] = 0, method: str = 'standardise', inverse: bool = False, mean=None, std=None) -> None:
         """
         This function transforms a xarray.dataarray (along dim) or pandas.DataFrame (along axis) either
         with mean=0 and std=1 (`method=standardise`) or centers the data with mean=0 and no change in data scale
@@ -247,11 +247,19 @@
         else:
             raise NotImplementedError

+        def f_apply(data):
+            # apply an externally given mean (and std) instead of calculating them from the data
+            if method == "standardise":
+                return mean, std, statistics.standardise_apply(data, mean, std)
+            elif method == "centre":
+                return mean, None, statistics.centre_apply(data, mean)
+            else:
+                raise NotImplementedError
+
         if not inverse:
             if self._transform_method is not None:
                 raise AssertionError(f"Transform method is already set. Therefore, data was already transformed with "
                                     f"{self._transform_method}. Please perform inverse transformation of data first.")
-            self.mean, self.std, self.data = f(self.data)
+            self.mean, self.std, self.data = (f if mean is None else f_apply)(self.data)
             self._transform_method = method
         else:
             self.inverse_transform()
@@ -370,7 +378,7 @@ class DataPrep(object):
         :param coord: name of axis to slice
         :return:
         """
-        return data.loc[{coord: slice(start, end)}]
+        return data.loc[{coord: slice(str(start), str(end))}]

     def check_for_negative_concentrations(self, data: xr.DataArray, minimum: int = 0) -> xr.DataArray:
         """
@@ -387,8 +395,7 @@
         return data

     def get_transposed_history(self):
-        if self.history is not None:
-            return self.history.transpose("datetime", "window", "Stations", "variables")
+        return self.history.transpose("datetime", "window", "Stations", "variables")


 if __name__ == "__main__":
diff --git a/src/helpers.py b/src/helpers.py
index c33684508b7b36cbebc3cbf3ac826d1779f9df50..e1496c3b232db29194878892647c59581b6a70a3 100644
--- a/src/helpers.py
+++ b/src/helpers.py
@@ -190,3 +190,8 @@ def float_round(number: float, decimals: int = 0, round_type: Callable = math.ce
     """
     multiplier = 10. ** decimals
     return round_type(number * multiplier) / multiplier
+
+
+def dict_pop(dct: Dict, pop_keys):
+    """Return a copy of `dct` without the given keys."""
+    pop_keys = to_list(pop_keys)
+    return {k: v for k, v in dct.items() if k not in pop_keys}
diff --git a/src/model_modules/inception_model.py b/src/model_modules/inception_model.py
index baf1d0250a319008c573109da765613ca87b2c77..1cb7656335495f0261abb434e4a203cb4e63887e 100644
--- a/src/model_modules/inception_model.py
+++ b/src/model_modules/inception_model.py
@@ -193,6 +193,7 @@ class InceptionModelBase:
         self.number_of_blocks += 1
         self.part_of_block = 0
         tower_build = {}
+        block_name = f"Block_{self.number_of_blocks}"
         for part, part_settings in tower_conv_parts.items():
             tower_build[part] = self.create_conv_tower(input_x, **part_settings, **kwargs)
         if 'max_pooling' in tower_pool_parts.keys():
@@ -205,7 +206,8 @@
             tower_build['maxpool'] = self.create_pool_tower(input_x, **tower_pool_parts, **kwargs)
             tower_build['avgpool'] = self.create_pool_tower(input_x, **tower_pool_parts, **kwargs, max_pooling=False)

-        block = keras.layers.concatenate(list(tower_build.values()), axis=3)
+        block = keras.layers.concatenate(list(tower_build.values()), axis=3,
+                                         name=block_name+"_Co")
         return block
diff --git a/src/run_modules/experiment_setup.py b/src/run_modules/experiment_setup.py
index 48f7c13e51622d7d52405b73c0a6f57537b5b476..4c3b8872575ea9929f1d4ba3f5a42e222ac2fff4 100644
--- a/src/run_modules/experiment_setup.py
+++ b/src/run_modules/experiment_setup.py
@@ -19,6 +19,7 @@ DEFAULT_STATIONS = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBY
 DEFAULT_VAR_ALL_DICT = {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum', 'u': 'average_values',
                         'v': 'average_values', 'no': 'dma8eu', 'no2': 'dma8eu', 'cloudcover': 'average_values',
                         'pblheight': 'maximum'}
+DEFAULT_TRANSFORMATION = {"scope": "data", "method": "standardise", "mean": "estimate"}


 class ExperimentSetup(RunEnvironment):
@@ -33,7 +34,7 @@ class ExperimentSetup(RunEnvironment):
                  limit_nan_fill=None, train_start=None, train_end=None, val_start=None, val_end=None, test_start=None,
                  test_end=None, use_all_stations_on_all_data_sets=True, trainable=None, fraction_of_train=None,
                  experiment_path=None, plot_path=None, forecast_path=None, overwrite_local_data=None, sampling="daily",
-                 create_new_model=None):
+                 create_new_model=None, permute_data_on_training=None, transformation=None):
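A hypothetical call wiring both new arguments through the setup (argument values are examples only; all other
parameters keep their defaults):

```python
ExperimentSetup(parser_args=None,
                transformation={"scope": "data", "method": "standardise", "mean": "estimate"},
                permute_data_on_training=True)
```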

         # create run framework
         super().__init__()

@@ -45,6 +46,7 @@
             trainable = True
         self._set_param("trainable", trainable, default=True)
         self._set_param("fraction_of_training", fraction_of_train, default=0.8)
+        self._set_param("permute_data", permute_data_on_training, default=False, scope="general.train")

         # set experiment name
         exp_date = self._get_parser_args(parser_args).get("experiment_date")
@@ -77,6 +79,8 @@
         self._set_param("window_history_size", window_history_size, default=13)
         self._set_param("overwrite_local_data", overwrite_local_data, default=False, scope="general.preprocessing")
         self._set_param("sampling", sampling)
+        self._set_param("transformation", transformation, default=DEFAULT_TRANSFORMATION)
+        self._set_param("transformation", None, scope="general.preprocessing")

         # target
         self._set_param("target_var", target_var, default="o3")
diff --git a/src/run_modules/pre_processing.py b/src/run_modules/pre_processing.py
index 4660a8116b6d0b860a7d0d50b92cee5e0deb77d8..3263f5c4562eeac321c7ce621df551fdf6373ba0 100644
--- a/src/run_modules/pre_processing.py
+++ b/src/run_modules/pre_processing.py
@@ -12,7 +12,7 @@ from src.run_modules.run_environment import RunEnvironment

 DEFAULT_ARGS_LIST = ["data_path", "network", "stations", "variables", "interpolate_dim", "target_dim", "target_var"]
 DEFAULT_KWARGS_LIST = ["limit_nan_fill", "window_history_size", "window_lead_time", "statistics_per_var",
-                       "station_type", "overwrite_local_data", "start", "end", "sampling"]
+                       "station_type", "overwrite_local_data", "start", "end", "sampling", "transformation"]


 class PreProcessing(RunEnvironment):
@@ -35,7 +35,8 @@
     def _run(self):
         args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope="general.preprocessing")
         kwargs = self.data_store.create_args_dict(DEFAULT_KWARGS_LIST, scope="general.preprocessing")
-        valid_stations = self.check_valid_stations(args, kwargs, self.data_store.get("stations", "general"), load_tmp=False)
+        stations = self.data_store.get("stations", "general")
+        valid_stations = self.check_valid_stations(args, kwargs, stations, load_tmp=False, save_tmp=False)
         self.data_store.set("stations", valid_stations, "general")
         self.split_train_val_test()
         self.report_pre_processing()
@@ -53,11 +54,19 @@
         logging.debug(f"TEST SHAPE OF GENERATOR CALL: {self.data_store.get('generator', 'general.test')[0][0].shape}"
                       f"{self.data_store.get('generator', 'general.test')[0][1].shape}")

-    def split_train_val_test(self):
+    def split_train_val_test(self) -> None:
+        """
+        Splits the data into all subsets. Currently: train, val, test and train_val (the latter is only the merge of
+        train and val, but kept as a separate generator). IMPORTANT: Do not change the order of execution of
+        create_set_split. The train subset always needs to be executed first, to set a proper transformation.
+        """
         fraction_of_training = self.data_store.get("fraction_of_training", "general")
         stations = self.data_store.get("stations", "general")
         train_index, val_index, test_index, train_val_index = self.split_set_indices(len(stations), fraction_of_training)
         subset_names = ["train", "val", "test", "train_val"]
+        if subset_names[0] != "train":  # pragma: no cover
+            raise AssertionError(f"Make sure that the train subset is always at the first execution position! Given "
+                                 f"subset order was: {subset_names}.")
         for (ind, scope) in zip([train_index, val_index, test_index, train_val_index], subset_names):
             self.create_set_split(ind, scope)
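For orientation, the resulting split for the default fraction, as asserted in the existing unit tests
(`pre_processing` stands for a PreProcessing instance):

```python
train, val, test, train_val = pre_processing.split_set_indices(15, 0.8)
# -> slice(0, 10) for train, slice(10, 13) for val, slice(13, 15) for test,
#    and slice(0, 13) for train_val (the merge of train and val)
```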
Given subset" + f"order was: {subset_names}.") for (ind, scope) in zip([train_index, val_index, test_index, train_val_index], subset_names): self.create_set_split(ind, scope) @@ -79,7 +88,16 @@ class PreProcessing(RunEnvironment): train_val_index = slice(0, pos_test_split) return train_index, val_index, test_index, train_val_index - def create_set_split(self, index_list, set_name): + def create_set_split(self, index_list: slice, set_name) -> None: + """ + Create the subset for given split index and stores the DataGenerator with given set name in data store as + `generator`. Checks for all valid stations using the default (kw)args for given scope and creates the + DataGenerator for all valid stations. Also sets all transformation information, if subset is training set. Make + sure, that the train set is executed first, and all other subsets afterwards. + :param index_list: list of all stations to use for the set. If attribute use_all_stations_on_all_data_sets=True, + this list is ignored. + :param set_name: name to load/save all information from/to data store without the leading general prefix. + """ scope = f"general.{set_name}" args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope) kwargs = self.data_store.create_args_dict(DEFAULT_KWARGS_LIST, scope) @@ -89,14 +107,16 @@ class PreProcessing(RunEnvironment): else: set_stations = stations[index_list] logging.debug(f"{set_name.capitalize()} stations (len={len(set_stations)}): {set_stations}") - set_stations = self.check_valid_stations(args, kwargs, set_stations) + set_stations = self.check_valid_stations(args, kwargs, set_stations, load_tmp=False) self.data_store.set("stations", set_stations, scope) set_args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope) data_set = DataGenerator(**set_args, **kwargs) self.data_store.set("generator", data_set, scope) + if set_name == "train": + self.data_store.set("transformation", data_set.transformation, "general") @staticmethod - def check_valid_stations(args: Dict, kwargs: Dict, all_stations: List[str], load_tmp=True): + def check_valid_stations(args: Dict, kwargs: Dict, all_stations: List[str], load_tmp=True, save_tmp=True): """ Check if all given stations in `all_stations` are valid. Valid means, that there is data available for the given time range (is included in `kwargs`). The shape and the loading time are logged in debug mode. 
@@ -118,7 +138,8 @@
                 t_inner.run()
                 try:
                     # (history, label) = data_gen[station]
-                    data = data_gen.get_data_generator(key=station, local_tmp_storage=load_tmp)
+                    data = data_gen.get_data_generator(key=station, load_local_tmp_storage=load_tmp,
+                                                       save_local_tmp_storage=save_tmp)
                     valid_stations.append(station)
                     logging.debug(f'{station}: history_shape = {data.history.transpose("datetime", "window", "Stations", "variables").shape}')
                     logging.debug(f"{station}: loading time = {t_inner}")
diff --git a/src/run_modules/training.py b/src/run_modules/training.py
index 7a522af0298bcabee62579f68bd29ed123cac7b0..df60c4f2f8dff4a9acb82920ad3c1d203813033d 100644
--- a/src/run_modules/training.py
+++ b/src/run_modules/training.py
@@ -65,7 +65,8 @@
         :param mode: name of set, should be from ["train", "val", "test"]
         """
         gen = self.data_store.get("generator", f"general.{mode}")
-        setattr(self, f"{mode}_set", Distributor(gen, self.model, self.batch_size))
+        permute_data = self.data_store.get_default("permute_data", f"general.{mode}", default=False)
+        setattr(self, f"{mode}_set", Distributor(gen, self.model, self.batch_size, permute_data=permute_data))

     def set_generators(self) -> None:
         """
diff --git a/src/statistics.py b/src/statistics.py
index e3481d0e0f0561ac8a903648a69e92c6d6acc40d..26b2be8854c51584f20b753717ea94cc12967369 100644
--- a/src/statistics.py
+++ b/src/statistics.py
@@ -15,11 +15,11 @@ Data = Union[xr.DataArray, pd.DataFrame]


 def apply_inverse_transformation(data, mean, std=None, method="standardise"):
-    if method == 'standardise':
+    if method == 'standardise':  # pragma: no branch
         return standardise_inverse(data, mean, std)
-    elif method == 'centre':
+    elif method == 'centre':  # pragma: no branch
         return centre_inverse(data, mean)
-    elif method == 'normalise':
+    elif method == 'normalise':  # pragma: no cover
         # use min/max of data or given min/max
         raise NotImplementedError
     else:
@@ -52,6 +52,17 @@ def standardise_inverse(data: Data, mean: Data, std: Data) -> Data:
     return data * std + mean


+def standardise_apply(data: Data, mean: Data, std: Data) -> Data:
+    """
+    This applies `standardise` on data using the given mean and std.
+    :param data: data to transform
+    :param mean: mean to subtract from data
+    :param std: std to divide data by
+    :return: standardised data
+    """
+    return (data - mean) / std
+
+
 def centre(data: Data, dim: Union[str, int]) -> Tuple[Data, None, Data]:
     """
     This function centres a xarray.dataarray (along dim) or pandas.DataFrame (along axis) to mean=0
@@ -77,6 +88,17 @@ def centre_inverse(data: Data, mean: Data) -> Data:
     return data + mean


+def centre_apply(data: Data, mean: Data) -> Data:
+    """
+    This applies `centre` on data using the given mean.
+    :param data: data to transform
+    :param mean: mean to subtract from data
+    :return: centred data
+    """
+    return data - mean
+
+
 def mean_squared_error(a, b):
     return np.square(a - b).mean()
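The apply/inverse pair is easiest to see on a toy frame; a minimal, self-contained sketch (the mean and std would
normally be externally given statistics, e.g. from the train set):

```python
import numpy as np
import pandas as pd

from src.statistics import standardise_apply, apply_inverse_transformation

data = pd.DataFrame({"o3": [40., 50., 60.], "temp": [10., 12., 14.]})
mean, std = data.mean(), data.std()          # stand-ins for given statistics
scaled = standardise_apply(data, mean, std)  # (data - mean) / std
recovered = apply_inverse_transformation(scaled, mean, std, method="standardise")
np.testing.assert_array_almost_equal(data, recovered)
```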
diff --git a/test/test_data_handling/test_data_distributor.py b/test/test_data_handling/test_data_distributor.py
index 4c6dbb1c38f2e4a49e53883fbe3cb33cb565118a..a26e76a0e7f3ef0f5cdbedc07d73a690116966c9 100644
--- a/test/test_data_handling/test_data_distributor.py
+++ b/test/test_data_handling/test_data_distributor.py
@@ -37,7 +37,7 @@ class TestDistributor:

     def test_init_defaults(self, distributor):
         assert distributor.batch_size == 256
-        assert distributor.fit_call is True
+        assert distributor.do_data_permutation is False

     def test_get_model_rank(self, distributor, model_with_minor_branch):
         assert distributor._get_model_rank() == 1
@@ -73,3 +73,28 @@
         d = Distributor(gen, model)
         expected = math.ceil(len(gen[0][0]) / 256) + math.ceil(len(gen[1][0]) / 256)
         assert len(d) == expected
+
+    def test_permute_data_no_permutation(self, distributor):
+        x = np.array(range(20)).reshape(2, 10).T
+        y = np.array(range(10)).reshape(10, 1)
+        x_perm, y_perm = distributor._permute_data(x, y)
+        assert np.testing.assert_equal(x, x_perm) is None
+        assert np.testing.assert_equal(y, y_perm) is None
+
+    def test_permute_data(self, distributor):
+        x = np.array(range(20)).reshape(2, 10).T
+        y = np.array(range(10)).reshape(10, 1)
+        distributor.do_data_permutation = True
+        x_perm, y_perm = distributor._permute_data(x, y)
+        assert x_perm[0, 0] == y_perm[0]
+        assert x_perm[0, 1] == y_perm[0] + 10
+        assert x_perm[5, 0] == y_perm[5]
+        assert x_perm[5, 1] == y_perm[5] + 10
+        assert x_perm[-1, 0] == y_perm[-1]
+        assert x_perm[-1, 1] == y_perm[-1] + 10
+        # resort x_perm and compare if equal to x
+        x_perm.sort(axis=0)
+        y_perm.sort(axis=0)
+        assert np.testing.assert_equal(x, x_perm) is None
+        assert np.testing.assert_equal(y, y_perm) is None
+
diff --git a/test/test_data_handling/test_data_generator.py b/test/test_data_handling/test_data_generator.py
index 142acd166604951352ad6686548c2cb76f609ce0..9bf11154609afa9ada2b488455f7a341a41d21ae 100644
--- a/test/test_data_handling/test_data_generator.py
+++ b/test/test_data_handling/test_data_generator.py
@@ -1,12 +1,15 @@
 import os
+import operator as op

 import pytest
 import shutil
 import numpy as np
+import xarray as xr
 import pickle

 from src.data_handling.data_generator import DataGenerator
 from src.data_handling.data_preparation import DataPrep
+from src.join import EmptyQueryResult


 class TestDataGenerator:
@@ -22,6 +25,56 @@
         return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'],
                              'datetime', 'variables', 'o3', start=2010, end=2014)

+    @pytest.fixture
+    def gen_with_transformation(self):
+        return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'],
+                             'datetime', 'variables', 'o3', start=2010, end=2014,
+                             transformation={"scope": "data", "mean": "estimate"},
+                             statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})
+
+    @pytest.fixture
+    def gen_no_init(self):
+        generator = object.__new__(DataGenerator)
+        path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'data'))
+        generator.data_path = path
+        if not os.path.exists(path):
+            os.makedirs(path)
+        generator.stations = ["DEBW107", "DEBW013", "DEBW001"]
+        generator.network = "AIRBASE"
+        generator.variables = ["temp", "o3"]
+        generator.station_type = "background"
+        generator.kwargs = {"start": 2010, "end": 2014, "statistics_per_var":
{'o3': 'dma8eu', 'temp': 'maximum'}} + return generator + + @pytest.fixture + def accurate_transformation(self, gen_no_init): + tmp = np.nan + for station in gen_no_init.stations: + try: + data_prep = DataPrep(gen_no_init.data_path, gen_no_init.network, station, gen_no_init.variables, + station_type=gen_no_init.station_type, **gen_no_init.kwargs) + tmp = data_prep.data.combine_first(tmp) + except EmptyQueryResult: + continue + mean_expected = tmp.mean(dim=["Stations", "datetime"]) + std_expected = tmp.std(dim=["Stations", "datetime"]) + return mean_expected, std_expected + + @pytest.fixture + def estimated_transformation(self, gen_no_init): + mean, std = None, None + for station in gen_no_init.stations: + try: + data_prep = DataPrep(gen_no_init.data_path, gen_no_init.network, station, gen_no_init.variables, + station_type=gen_no_init.station_type, **gen_no_init.kwargs) + mean = data_prep.data.mean(axis=1).combine_first(mean) + std = data_prep.data.std(axis=1).combine_first(std) + except EmptyQueryResult: + continue + mean_expected = mean.mean(axis=0) + std_expected = std.mean(axis=0) + return mean_expected, std_expected + class DummyDataPrep: def __init__(self, data): self.station = "DEBW107" @@ -41,7 +94,7 @@ class TestDataGenerator: assert gen.limit_nan_fill == 1 assert gen.window_history_size == 7 assert gen.window_lead_time == 4 - assert gen.transform_method == "standardise" + assert gen.transformation is None assert gen.kwargs == {"start": 2010, "end": 2014} def test_repr(self, gen): @@ -76,6 +129,72 @@ class TestDataGenerator: assert station[1].data.shape[-1] == gen.window_lead_time assert station[0].data.shape[1] == gen.window_history_size + 1 + def test_setup_transformation_no_transformation(self, gen_no_init): + assert gen_no_init.setup_transformation(None) is None + assert gen_no_init.setup_transformation({}) == {"method": "standardise", "mean": None, "std": None} + assert gen_no_init.setup_transformation({"scope": "station", "mean": "accurate"}) == \ + {"scope": "station", "method": "standardise", "mean": None, "std": None} + + def test_setup_transformation_calculate_statistics(self, gen_no_init): + transformation = {"scope": "data", "mean": "accurate"} + res_acc = gen_no_init.setup_transformation(transformation) + assert sorted(res_acc.keys()) == sorted(["scope", "mean", "std", "method"]) + assert isinstance(res_acc["mean"], xr.DataArray) + assert isinstance(res_acc["std"], xr.DataArray) + transformation["mean"] = "estimate" + res_est = gen_no_init.setup_transformation(transformation) + assert sorted(res_est.keys()) == sorted(["scope", "mean", "std", "method"]) + assert isinstance(res_est["mean"], xr.DataArray) + assert isinstance(res_est["std"], xr.DataArray) + assert np.testing.assert_array_compare(op.__ne__, res_est["std"].values, res_acc["std"].values) is None + + def test_setup_transformation_use_given_statistics(self, gen_no_init): + mean = xr.DataArray([30, 15], coords={"variables": ["o3", "temp"]}, dims=["variables"]) + transformation = {"scope": "data", "method": "centre", "mean": mean} + res = gen_no_init.setup_transformation(transformation) + assert np.testing.assert_equal(res["mean"].values, mean.values) is None + assert res["std"] is None + + def test_setup_transformation_errors(self, gen_no_init): + transformation = {"scope": "random", "mean": "accurate"} + with pytest.raises(ValueError): + gen_no_init.setup_transformation(transformation) + transformation = {"scope": "data", "mean": "fit"} + with pytest.raises(ValueError): + 
gen_no_init.setup_transformation(transformation) + + def test_calculate_accurate_transformation_standardise(self, gen_no_init, accurate_transformation): + mean_expected, std_expected = accurate_transformation + mean, std = gen_no_init.calculate_accurate_transformation("standardise") + assert np.testing.assert_almost_equal(mean_expected.values, mean.values) is None + assert np.testing.assert_almost_equal(std_expected.values, std.values) is None + + def test_calculate_accurate_transformation_centre(self, gen_no_init, accurate_transformation): + mean_expected, _ = accurate_transformation + mean, std = gen_no_init.calculate_accurate_transformation("centre") + assert np.testing.assert_almost_equal(mean_expected.values, mean.values) is None + assert std is None + + def test_calculate_accurate_transformation_all_others(self, gen_no_init): + with pytest.raises(NotImplementedError): + gen_no_init.calculate_accurate_transformation("normalise") + + def test_calculate_estimated_transformation_standardise(self, gen_no_init, estimated_transformation): + mean_expected, std_expected = estimated_transformation + mean, std = gen_no_init.calculate_estimated_transformation("standardise") + assert np.testing.assert_almost_equal(mean_expected.values, mean.values) is None + assert np.testing.assert_almost_equal(std_expected.values, std.values) is None + + def test_calculate_estimated_transformation_centre(self, gen_no_init, estimated_transformation): + mean_expected, _ = estimated_transformation + mean, std = gen_no_init.calculate_estimated_transformation("centre") + assert np.testing.assert_almost_equal(mean_expected.values, mean.values) is None + assert std is None + + def test_calculate_estimated_transformation_all_others(self, gen_no_init): + with pytest.raises(NotImplementedError): + gen_no_init.calculate_estimated_transformation("normalise") + def test_get_station_key(self, gen): gen.stations.append("DEBW108") f = gen.get_station_key @@ -104,7 +223,7 @@ class TestDataGenerator: if os.path.exists(file): os.remove(file) assert not os.path.exists(file) - assert isinstance(gen.get_data_generator("DEBW107", local_tmp_storage=False), DataPrep) + assert isinstance(gen.get_data_generator("DEBW107", load_local_tmp_storage=False), DataPrep) t = os.stat(file).st_ctime assert os.path.exists(file) assert isinstance(gen.get_data_generator("DEBW107"), DataPrep) @@ -113,6 +232,12 @@ class TestDataGenerator: assert isinstance(gen.get_data_generator("DEBW107"), DataPrep) assert os.stat(file).st_ctime > t + def test_get_data_generator_transform(self, gen_with_transformation): + gen = gen_with_transformation + data = gen.get_data_generator("DEBW107", load_local_tmp_storage=False, save_local_tmp_storage=False) + assert data._transform_method == "standardise" + assert data.mean is not None + def test_save_pickle_data(self, gen): file = os.path.join(gen.data_path_tmp, f"DEBW107_{'_'.join(sorted(gen.variables))}_2010_2014_.pickle") if os.path.exists(file): diff --git a/test/test_data_handling/test_data_preparation.py b/test/test_data_handling/test_data_preparation.py index 72bacaf9cc1e5a9dc9736b8e8eb7161f35d8ea69..ac449c4dc6d4c83a457eccc93a766ec4f17f58c9 100644 --- a/test/test_data_handling/test_data_preparation.py +++ b/test/test_data_handling/test_data_preparation.py @@ -152,6 +152,26 @@ class TestDataPrep: assert isinstance(data.mean, xr.DataArray) assert isinstance(data.std, xr.DataArray) + def test_transform_standardise_apply(self, data): + assert data._transform_method is None + assert data.mean is None + assert data.std is 
None + data_mean_orig = data.data.mean('datetime').variable.values + data_std_orig = data.data.std('datetime').variable.values + mean_external = np.array([20, 12]) + std_external = np.array([15, 5]) + mean = xr.DataArray(mean_external, coords={"variables": ['o3', 'temp']}, dims=["variables"]) + std = xr.DataArray(std_external, coords={"variables": ['o3', 'temp']}, dims=["variables"]) + data.transform('datetime', mean=mean, std=std) + assert all(data.mean.values == mean_external) + assert all(data.std.values == std_external) + data_mean_transformed = data.data.mean('datetime').variable.values + data_std_transformed = data.data.std('datetime').variable.values + data_mean_expected = (data_mean_orig - mean_external) / std_external # mean scales as any other data + data_std_expected = data_std_orig / std_external # std scales by given std + assert np.testing.assert_almost_equal(data_mean_transformed, data_mean_expected) is None + assert np.testing.assert_almost_equal(data_std_transformed, data_std_expected) is None + @pytest.mark.parametrize('mean, std, method, msg', [(10, 3, 'standardise', ''), (6, None, 'standardise', 'std, '), (None, 3, 'standardise', 'mean, '), (19, None, 'centre', ''), (None, 2, 'centre', 'mean, '), (8, 2, 'centre', ''), @@ -168,12 +188,29 @@ class TestDataPrep: assert data._transform_method is None assert data.mean is None assert data.std is None - data_std_org = data.data.std('datetime'). variable.values + data_std_orig = data.data.std('datetime'). variable.values data.transform('datetime', 'centre') assert data._transform_method == 'centre' assert np.testing.assert_almost_equal(data.data.mean('datetime').variable.values, np.array([[0, 0]])) is None - assert np.testing.assert_almost_equal(data.data.std('datetime').variable.values, data_std_org) is None + assert np.testing.assert_almost_equal(data.data.std('datetime').variable.values, data_std_orig) is None + assert data.std is None + + def test_transform_centre_apply(self, data): + assert data._transform_method is None + assert data.mean is None + assert data.std is None + data_mean_orig = data.data.mean('datetime').variable.values + data_std_orig = data.data.std('datetime').variable.values + mean_external = np.array([20, 12]) + mean = xr.DataArray(mean_external, coords={"variables": ['o3', 'temp']}, dims=["variables"]) + data.transform('datetime', 'centre', mean=mean) + assert all(data.mean.values == mean_external) assert data.std is None + data_mean_transformed = data.data.mean('datetime').variable.values + data_std_transformed = data.data.std('datetime').variable.values + data_mean_expected = (data_mean_orig - mean_external) # mean scales as any other data + assert np.testing.assert_almost_equal(data_mean_transformed, data_mean_expected) is None + assert np.testing.assert_almost_equal(data_std_transformed, data_std_orig) is None @pytest.mark.parametrize('method', ['standardise', 'centre']) def test_transform_inverse(self, data, method): diff --git a/test/test_datastore.py b/test/test_datastore.py index 95a58deafc915dd6193960e77bb99cc8ab8d85cb..9fcb319f51954b365c59274a4a9744f093e155f1 100644 --- a/test/test_datastore.py +++ b/test/test_datastore.py @@ -30,6 +30,14 @@ class TestDataStoreByVariable: def ds(self): return DataStoreByVariable() + @pytest.fixture + def ds_with_content(self, ds): + ds.set("tester1", 1, "general") + ds.set("tester2", 11, "general") + ds.set("tester2", 10, "general.sub") + ds.set("tester3", 21, "general") + return ds + def test_put(self, ds): ds.set("number", 3, "general.subscope") assert 
ds._store["number"]["general.subscope"] == 3 @@ -131,15 +139,18 @@ class TestDataStoreByVariable: assert ds.search_scope("general.sub.sub", current_scope_only=False, return_all=True) == \ [("number", "general.sub.sub", "ABC"), ("number1", "general.sub", 22), ("number2", "general.sub.sub", 3)] - def test_create_args_dict(self, ds): - ds.set("tester1", 1, "general") - ds.set("tester2", 11, "general") - ds.set("tester2", 10, "general.sub") - ds.set("tester3", 21, "general") + def test_create_args_dict_default_scope(self, ds_with_content): args = ["tester1", "tester2", "tester3", "tester4"] - assert ds.create_args_dict(args) == {"tester1": 1, "tester2": 11, "tester3": 21} - assert ds.create_args_dict(args, "general.sub") == {"tester1": 1, "tester2": 10, "tester3": 21} - assert ds.create_args_dict(["notAvail", "alsonot"]) == {} + assert ds_with_content.create_args_dict(args) == {"tester1": 1, "tester2": 11, "tester3": 21} + + def test_create_args_dict_given_scope(self, ds_with_content): + args = ["tester1", "tester2", "tester3", "tester4"] + assert ds_with_content.create_args_dict(args, "general.sub") == {"tester1": 1, "tester2": 10, "tester3": 21} + + def test_create_args_dict_missing_entry(self, ds_with_content): + args = ["tester1", "notAvail", "tester4"] + assert ds_with_content.create_args_dict(["notAvail", "alsonot"]) == {} + assert ds_with_content.create_args_dict(args) == {"tester1": 1} def test_set_args_from_dict(self, ds): ds.set_args_from_dict({"tester1": 1, "tester2": 10, "tester3": 21}) @@ -157,6 +168,14 @@ class TestDataStoreByScope: def ds(self): return DataStoreByScope() + @pytest.fixture + def ds_with_content(self, ds): + ds.set("tester1", 1, "general") + ds.set("tester2", 11, "general") + ds.set("tester2", 10, "general.sub") + ds.set("tester3", 21, "general") + return ds + def test_put_with_scope(self, ds): ds.set("number", 3, "general.subscope") assert ds._store["general.subscope"]["number"] == 3 @@ -258,15 +277,18 @@ class TestDataStoreByScope: assert ds.search_scope("general.sub.sub", current_scope_only=False, return_all=True) == \ [("number", "general.sub.sub", "ABC"), ("number1", "general.sub", 22), ("number2", "general.sub.sub", 3)] - def test_create_args_dict(self, ds): - ds.set("tester1", 1, "general") - ds.set("tester2", 11, "general") - ds.set("tester2", 10, "general.sub") - ds.set("tester3", 21, "general") + def test_create_args_dict_default_scope(self, ds_with_content): args = ["tester1", "tester2", "tester3", "tester4"] - assert ds.create_args_dict(args) == {"tester1": 1, "tester2": 11, "tester3": 21} - assert ds.create_args_dict(args, "general.sub") == {"tester1": 1, "tester2": 10, "tester3": 21} - assert ds.create_args_dict(["notAvail", "alsonot"]) == {} + assert ds_with_content.create_args_dict(args) == {"tester1": 1, "tester2": 11, "tester3": 21} + + def test_create_args_dict_given_scope(self, ds_with_content): + args = ["tester1", "tester2", "tester3", "tester4"] + assert ds_with_content.create_args_dict(args, "general.sub") == {"tester1": 1, "tester2": 10, "tester3": 21} + + def test_create_args_dict_missing_entry(self, ds_with_content): + args = ["tester1", "notAvail", "tester4"] + assert ds_with_content.create_args_dict(["notAvail", "alsonot"]) == {} + assert ds_with_content.create_args_dict(args) == {"tester1": 1} def test_set_args_from_dict(self, ds): ds.set_args_from_dict({"tester1": 1, "tester2": 10, "tester3": 21}) diff --git a/test/test_model_modules/test_inception_model.py b/test/test_model_modules/test_inception_model.py index 
6847b24f738428550f4b59faf4c00f962b90208e..281ce1a315526df05918bc2b07918eff3c9f276d 100644 --- a/test/test_model_modules/test_inception_model.py +++ b/test/test_model_modules/test_inception_model.py @@ -237,7 +237,10 @@ class TestInceptionModelBase: assert isinstance(self.step_in(block_pool2._keras_history[0], depth=2), keras.layers.pooling.AveragePooling2D) assert self.step_in(block_pool2._keras_history[0], depth=3).name == 'Block_1d_Pad' assert isinstance(self.step_in(block_pool2._keras_history[0], depth=3), ReflectionPadding2D) - + # check naming of concat layer + assert block.name == 'Block_1_Co/concat:0' + assert block._keras_history[0].name == 'Block_1_Co' + assert isinstance(block._keras_history[0], keras.layers.merge.Concatenate) # next block opts['input_x'] = block opts['tower_pool_parts']['max_pooling'] = True @@ -261,6 +264,10 @@ class TestInceptionModelBase: assert isinstance(self.step_in(block_pool._keras_history[0], depth=2), keras.layers.pooling.MaxPooling2D) assert self.step_in(block_pool._keras_history[0], depth=3).name == 'Block_2c_Pad' assert isinstance(self.step_in(block_pool._keras_history[0], depth=3), ReflectionPadding2D) + # check naming of concat layer + assert block.name == 'Block_2_Co/concat:0' + assert block._keras_history[0].name == 'Block_2_Co' + assert isinstance(block._keras_history[0], keras.layers.merge.Concatenate) def test_inception_block_invalid_batchnorm(self, base, input_x): conv = {'tower_1': {'reduction_filter': 64, @@ -278,6 +285,7 @@ class TestInceptionModelBase: block = base.inception_block(**opts) assert "max_pooling has to be either a bool or empty. Given was: yes" in str(einfo.value) + def test_batch_normalisation(self, base, input_x): # import keras base.part_of_block += 1 diff --git a/test/test_modules/test_pre_processing.py b/test/test_modules/test_pre_processing.py index 9d1feb03ac71b980f6a4cd1b0e6cac2a52d9625b..29172a1b8500b605859e925574535c6158c7d805 100644 --- a/test/test_modules/test_pre_processing.py +++ b/test/test_modules/test_pre_processing.py @@ -54,7 +54,8 @@ class TestPreProcessing: assert obj_with_exp_setup.data_store.search_name("generator") == [] obj_with_exp_setup.split_train_val_test() data_store = obj_with_exp_setup.data_store - assert data_store.search_scope("general.train") == sorted(["generator", "start", "end", "stations"]) + expected_params = ["generator", "start", "end", "stations", "permute_data"] + assert data_store.search_scope("general.train") == sorted(expected_params) assert data_store.search_name("generator") == sorted(["general.train", "general.val", "general.test", "general.train_val"]) @@ -100,13 +101,3 @@ class TestPreProcessing: assert dummy_list[val] == list(range(10, 13)) assert dummy_list[test] == list(range(13, 15)) assert dummy_list[train_val] == list(range(0, 13)) - - def test_create_args_dict_default_scope(self, obj_super_init): - assert obj_super_init.data_store.create_args_dict(["NAME1", "NAME2"]) == {"NAME1": 1, "NAME2": 2} - - def test_create_args_dict_given_scope(self, obj_super_init): - assert obj_super_init.data_store.create_args_dict(["NAME1", "NAME2"], scope="general.sub") == {"NAME1": 10, "NAME2": 2} - - def test_create_args_dict_missing_entry(self, obj_super_init): - assert obj_super_init.data_store.create_args_dict(["NAME5", "NAME2"]) == {"NAME2": 2} - assert obj_super_init.data_store.create_args_dict(["NAME4", "NAME2"]) == {"NAME2": 2} diff --git a/test/test_statistics.py b/test/test_statistics.py index 308ac655787e69f90b45e65e7e7df8f35875f652..cad915564aac675cadda0f625dca1a073b2c8959 
100644 --- a/test/test_statistics.py +++ b/test/test_statistics.py @@ -3,7 +3,10 @@ import pandas as pd import pytest import xarray as xr -from src.statistics import standardise, standardise_inverse, centre, centre_inverse +from src.statistics import standardise, standardise_inverse, standardise_apply, centre, centre_inverse, centre_apply,\ + apply_inverse_transformation + +lazy = pytest.lazy_fixture @pytest.fixture(scope='module') @@ -18,44 +21,95 @@ def pandas(input_data): return pd.DataFrame(input_data) +@pytest.fixture(scope='module') +def pd_mean(): + return [2, 10, 3] + + +@pytest.fixture(scope='module') +def pd_std(): + return [3, 2, 3] + + @pytest.fixture(scope='module') def xarray(input_data): - return xr.DataArray(input_data, dims=['index', 'value']) + shape = input_data.shape + coords = {'index': range(shape[0]), 'value': range(shape[1])} + return xr.DataArray(input_data, coords=coords, dims=coords.keys()) + + +@pytest.fixture(scope='module') +def xr_mean(input_data): + return xr.DataArray([2, 10, 3], coords={'value': range(3)}, dims=['value']) + + +@pytest.fixture(scope='module') +def xr_std(input_data): + return xr.DataArray([3, 2, 3], coords={'value': range(3)}, dims=['value']) class TestStandardise: - @pytest.mark.parametrize('data_org, dim', [(pytest.lazy_fixture('pandas'), 0), - (pytest.lazy_fixture('xarray'), 'index')]) - def test_standardise(self, data_org, dim): - mean, std, data = standardise(data_org, dim) + @pytest.mark.parametrize('data_orig, dim', [(lazy('pandas'), 0), + (lazy('xarray'), 'index')]) + def test_standardise(self, data_orig, dim): + mean, std, data = standardise(data_orig, dim) assert np.testing.assert_almost_equal(mean, [2, -5, 10], decimal=1) is None assert np.testing.assert_almost_equal(std, [2, 3, 1], decimal=1) is None assert np.testing.assert_almost_equal(data.mean(dim), [0, 0, 0]) is None assert np.testing.assert_almost_equal(data.std(dim), [1, 1, 1]) is None - @pytest.mark.parametrize('data_org, dim', [(pytest.lazy_fixture('pandas'), 0), - (pytest.lazy_fixture('xarray'), 'index')]) - def test_standardise_inverse(self, data_org, dim): - mean, std, data = standardise(data_org, dim) + @pytest.mark.parametrize('data_orig, dim', [(lazy('pandas'), 0), + (lazy('xarray'), 'index')]) + def test_standardise_inverse(self, data_orig, dim): + mean, std, data = standardise(data_orig, dim) data_recovered = standardise_inverse(data, mean, std) - assert np.testing.assert_array_almost_equal(data_org, data_recovered) is None + assert np.testing.assert_array_almost_equal(data_orig, data_recovered) is None + + @pytest.mark.parametrize('data_orig, dim', [(lazy('pandas'), 0), + (lazy('xarray'), 'index')]) + def test_apply_standardise_inverse(self, data_orig, dim): + mean, std, data = standardise(data_orig, dim) + data_recovered = apply_inverse_transformation(data, mean, std) + assert np.testing.assert_array_almost_equal(data_orig, data_recovered) is None + + @pytest.mark.parametrize('data_orig, mean, std, dim', [(lazy('pandas'), lazy('pd_mean'), lazy('pd_std'), 0), + (lazy('xarray'), lazy('xr_mean'), lazy('xr_std'), 'index')]) + def test_standardise_apply(self, data_orig, mean, std, dim): + data = standardise_apply(data_orig, mean, std) + mean_expected = (np.array([2, -5, 10]) - np.array([2, 10, 3])) / np.array([3, 2, 3]) + std_expected = np.array([2, 3, 1]) / np.array([3, 2, 3]) + assert np.testing.assert_almost_equal(data.mean(dim), mean_expected, decimal=1) is None + assert np.testing.assert_almost_equal(data.std(dim), std_expected, decimal=1) is None class 
TestCentre: - @pytest.mark.parametrize('data_org, dim', [(pytest.lazy_fixture('pandas'), 0), - (pytest.lazy_fixture('xarray'), 'index')]) - def test_centre(self, data_org, dim): - mean, std, data = centre(data_org, dim) + @pytest.mark.parametrize('data_orig, dim', [(lazy('pandas'), 0), + (lazy('xarray'), 'index')]) + def test_centre(self, data_orig, dim): + mean, std, data = centre(data_orig, dim) assert np.testing.assert_almost_equal(mean, [2, -5, 10], decimal=1) is None assert std is None assert np.testing.assert_almost_equal(data.mean(dim), [0, 0, 0]) is None - @pytest.mark.parametrize('data_org, dim', [(pytest.lazy_fixture('pandas'), 0), - (pytest.lazy_fixture('xarray'), 'index')]) - def test_centre_inverse(self, data_org, dim): - mean, _, data = centre(data_org, dim) + @pytest.mark.parametrize('data_orig, dim', [(lazy('pandas'), 0), + (lazy('xarray'), 'index')]) + def test_centre_inverse(self, data_orig, dim): + mean, _, data = centre(data_orig, dim) data_recovered = centre_inverse(data, mean) - assert np.testing.assert_array_almost_equal(data_org, data_recovered) is None - + assert np.testing.assert_array_almost_equal(data_orig, data_recovered) is None + + @pytest.mark.parametrize('data_orig, dim', [(lazy('pandas'), 0), + (lazy('xarray'), 'index')]) + def test_apply_centre_inverse(self, data_orig, dim): + mean, _, data = centre(data_orig, dim) + data_recovered = apply_inverse_transformation(data, mean, method="centre") + assert np.testing.assert_array_almost_equal(data_orig, data_recovered) is None + + @pytest.mark.parametrize('data_orig, mean, dim', [(lazy('pandas'), lazy('pd_mean'), 0), + (lazy('xarray'), lazy('xr_mean'), 'index')]) + def test_centre_apply(self, data_orig, mean, dim): + data = centre_apply(data_orig, mean) + mean_expected = np.array([2, -5, 10]) - np.array([2, 10, 3]) + assert np.testing.assert_almost_equal(data.mean(dim), mean_expected, decimal=1) is None