diff --git a/src/configuration/defaults.py b/src/configuration/defaults.py
index 0038bb5512d602150905f6504bcd5e135b127382..4bb1ab2eef43ce2230fb2dfd3781322c9fc405cf 100644
--- a/src/configuration/defaults.py
+++ b/src/configuration/defaults.py
@@ -13,7 +13,8 @@ DEFAULT_START = "1997-01-01"
 DEFAULT_END = "2017-12-31"
 DEFAULT_WINDOW_HISTORY_SIZE = 13
 DEFAULT_OVERWRITE_LOCAL_DATA = False
-DEFAULT_TRANSFORMATION = {"scope": "data", "method": "standardise", "mean": "estimate"}
+# DEFAULT_TRANSFORMATION = {"scope": "data", "method": "standardise", "mean": "estimate"}
+DEFAULT_TRANSFORMATION = {"scope": "data", "method": "standardise"}
 DEFAULT_HPC_LOGIN_LIST = ["ju", "hdfmll"]  # ju[wels} #hdfmll(ogin)
 DEFAULT_HPC_HOST_LIST = ["jw", "hdfmlc"]  # first part of node names for Juwels (jw[comp], hdfmlc(ompute).
 DEFAULT_CREATE_NEW_MODEL = True
diff --git a/src/data_handling/advanced_data_handling.py b/src/data_handling/advanced_data_handling.py
index e5c5214de9fe03f9e6c5de6e2ddcfbfb9987d052..f48bbb22cb5a52df540bf76b517f38e7b062511b 100644
--- a/src/data_handling/advanced_data_handling.py
+++ b/src/data_handling/advanced_data_handling.py
@@ -17,6 +17,7 @@ from typing import Union, List, Tuple
 import logging
 from functools import reduce
 from src.data_handling.data_preparation import StationPrep
+from src.helpers.join import EmptyQueryResult
 
 number = Union[float, int]
 
@@ -68,6 +69,10 @@ class AbstractDataPreparation:
     def own_args(cls, *args):
         return remove_items(inspect.getfullargspec(cls).args, ["self"] + list(args))
 
+    @classmethod
+    def transformation(cls, *args, **kwargs):
+        raise NotImplementedError
+
     def get_X(self, upsampling=False, as_numpy=False):
         raise NotImplementedError
 
@@ -254,6 +259,34 @@
         for d in data:
             d.coords[dim].values += np.timedelta64(*timedelta)
 
+    @classmethod
+    def transformation(cls, set_stations, **kwargs):
+        sp_keys = {k: kwargs[k] for k in cls._requirements if k in kwargs}
+        transformation_dict = sp_keys.pop("transformation")
+        if transformation_dict is None:
+            return
+
+        scope = transformation_dict.pop("scope")
+        method = transformation_dict.pop("method")
+        if transformation_dict.pop("mean", None) is not None:
+            return
+
+        mean, std = None, None
+        for station in set_stations:
+            try:
+                sp = StationPrep(station, transformation={"method": method}, **sp_keys)
+                mean = sp.mean.copy(deep=True) if mean is None else mean.combine_first(sp.mean)
+                std = sp.std.copy(deep=True) if std is None else std.combine_first(sp.std)
+            except (AttributeError, EmptyQueryResult):
+                continue
+        if mean is None:
+            return None
+        mean_estimated = mean.mean("Stations")
+        std_estimated = std.mean("Stations")
+        return {"scope": scope, "method": method, "mean": mean_estimated, "std": std_estimated}
+
+
+
 
 
 def run_data_prep():
diff --git a/src/run_modules/pre_processing.py b/src/run_modules/pre_processing.py
index c6ea67b87fc33a0952a5123754ab3fea62eee488..491b1530de3f935d5b8409e7f260da3276bc6aad 100644
--- a/src/run_modules/pre_processing.py
+++ b/src/run_modules/pre_processing.py
@@ -257,6 +257,10 @@
         """
         t_outer = TimeTracking()
         logging.info(f"check valid stations started{' (%s)' % set_name if set_name is not None else 'all'}")
+        # calculate transformation using train data
+        if set_name == "train":
+            self.transformation(data_preparation, set_stations)
+        # start station check
         collection = DataCollection()
         valid_stations = []
         kwargs = self.data_store.create_args_dict(data_preparation.requirements(), scope=set_name)
@@ -271,3 +275,12 @@
                      f"{len(set_stations)} valid stations.")
         return collection, valid_stations
 
+    def transformation(self, data_preparation, stations):
+        if hasattr(data_preparation, "transformation"):
+            kwargs = self.data_store.create_args_dict(data_preparation.requirements(), scope="train")
+            transformation_dict = data_preparation.transformation(stations, **kwargs)
+            if transformation_dict is not None:
+                self.data_store.set("transformation", transformation_dict)
+
+
+