Skip to content
Snippets Groups Projects
Commit 080c8029 authored by lukas leufen's avatar lukas leufen
Browse files

transformation is included in the preprocessing stage

parent cdf911a5
No related branches found
No related tags found
4 merge requests!136update release branch,!135Release v0.11.0,!134MLAir is decoupled from join,!119Resolve "Include advanced data handling in workflow"
Pipeline #41375 failed
...@@ -13,7 +13,8 @@ DEFAULT_START = "1997-01-01" ...@@ -13,7 +13,8 @@ DEFAULT_START = "1997-01-01"
DEFAULT_END = "2017-12-31" DEFAULT_END = "2017-12-31"
DEFAULT_WINDOW_HISTORY_SIZE = 13 DEFAULT_WINDOW_HISTORY_SIZE = 13
DEFAULT_OVERWRITE_LOCAL_DATA = False DEFAULT_OVERWRITE_LOCAL_DATA = False
DEFAULT_TRANSFORMATION = {"scope": "data", "method": "standardise", "mean": "estimate"} # DEFAULT_TRANSFORMATION = {"scope": "data", "method": "standardise", "mean": "estimate"}
DEFAULT_TRANSFORMATION = {"scope": "data", "method": "standardise"}
DEFAULT_HPC_LOGIN_LIST = ["ju", "hdfmll"] # ju[wels} #hdfmll(ogin) DEFAULT_HPC_LOGIN_LIST = ["ju", "hdfmll"] # ju[wels} #hdfmll(ogin)
DEFAULT_HPC_HOST_LIST = ["jw", "hdfmlc"] # first part of node names for Juwels (jw[comp], hdfmlc(ompute). DEFAULT_HPC_HOST_LIST = ["jw", "hdfmlc"] # first part of node names for Juwels (jw[comp], hdfmlc(ompute).
DEFAULT_CREATE_NEW_MODEL = True DEFAULT_CREATE_NEW_MODEL = True
......
...@@ -17,6 +17,7 @@ from typing import Union, List, Tuple ...@@ -17,6 +17,7 @@ from typing import Union, List, Tuple
import logging import logging
from functools import reduce from functools import reduce
from src.data_handling.data_preparation import StationPrep from src.data_handling.data_preparation import StationPrep
from src.helpers.join import EmptyQueryResult
number = Union[float, int] number = Union[float, int]
...@@ -68,6 +69,10 @@ class AbstractDataPreparation: ...@@ -68,6 +69,10 @@ class AbstractDataPreparation:
def own_args(cls, *args): def own_args(cls, *args):
return remove_items(inspect.getfullargspec(cls).args, ["self"] + list(args)) return remove_items(inspect.getfullargspec(cls).args, ["self"] + list(args))
@classmethod
def transformation(cls, *args, **kwargs):
    """Return a transformation specification for this data preparation class.

    Abstract hook with no base implementation: concrete subclasses (e.g.
    ``DefaultDataPreparation``) must override this. Implementations visible in
    this file return either ``None`` or a dict describing the transformation.

    :raises NotImplementedError: always, in the abstract base class.
    """
    raise NotImplementedError
def get_X(self, upsampling=False, as_numpy=False): def get_X(self, upsampling=False, as_numpy=False):
raise NotImplementedError raise NotImplementedError
...@@ -254,6 +259,34 @@ class DefaultDataPreparation(AbstractDataPreparation): ...@@ -254,6 +259,34 @@ class DefaultDataPreparation(AbstractDataPreparation):
for d in data: for d in data:
d.coords[dim].values += np.timedelta64(*timedelta) d.coords[dim].values += np.timedelta64(*timedelta)
@classmethod
def transformation(cls, set_stations, **kwargs):
    """Estimate a shared transformation (mean/std) over the given stations.

    Reads the ``transformation`` entry from ``kwargs`` (restricted to
    ``cls._requirements``). If it is ``None``, or it already carries a
    ``mean``, nothing is estimated and ``None`` is returned. Otherwise a
    ``StationPrep`` is built per station, per-station means/stds are merged
    via ``combine_first``, and their average over the "Stations" dimension is
    returned together with the original ``scope`` and ``method``.

    :param set_stations: iterable of station ids to estimate from
    :param kwargs: must contain "transformation"; remaining required keys are
        forwarded to ``StationPrep``
    :return: dict with keys "scope", "method", "mean", "std", or ``None`` if
        no estimation is required/possible.
    """
    sp_keys = {k: kwargs[k] for k in cls._requirements if k in kwargs}
    transformation_dict = sp_keys.pop("transformation")
    if transformation_dict is None:
        return None
    # Work on a shallow copy: the original code popped "scope"/"method"/"mean"
    # directly from the caller's dict (shared via the data store), silently
    # mutating the stored transformation settings.
    transformation_dict = dict(transformation_dict)
    scope = transformation_dict.pop("scope")
    method = transformation_dict.pop("method")
    if transformation_dict.pop("mean", None) is not None:
        # mean is already provided -> nothing to estimate
        return None
    mean, std = None, None
    for station in set_stations:
        try:
            sp = StationPrep(station, transformation={"method": method}, **sp_keys)
            mean = sp.mean.copy(deep=True) if mean is None else mean.combine_first(sp.mean)
            std = sp.std.copy(deep=True) if std is None else std.combine_first(sp.std)
        except (AttributeError, EmptyQueryResult):
            # station could not be prepared (e.g. no data returned) -> skip it
            continue
    if mean is None:
        # no station delivered statistics, estimation failed
        return None
    mean_estimated = mean.mean("Stations")
    std_estimated = std.mean("Stations")
    return {"scope": scope, "method": method, "mean": mean_estimated, "std": std_estimated}
def run_data_prep(): def run_data_prep():
......
...@@ -257,6 +257,10 @@ class PreProcessing(RunEnvironment): ...@@ -257,6 +257,10 @@ class PreProcessing(RunEnvironment):
""" """
t_outer = TimeTracking() t_outer = TimeTracking()
logging.info(f"check valid stations started{' (%s)' % set_name if set_name is not None else 'all'}") logging.info(f"check valid stations started{' (%s)' % set_name if set_name is not None else 'all'}")
# calculate transformation using train data
if set_name == "train":
self.transformation(data_preparation, set_stations)
# start station check
collection = DataCollection() collection = DataCollection()
valid_stations = [] valid_stations = []
kwargs = self.data_store.create_args_dict(data_preparation.requirements(), scope=set_name) kwargs = self.data_store.create_args_dict(data_preparation.requirements(), scope=set_name)
...@@ -271,3 +275,12 @@ class PreProcessing(RunEnvironment): ...@@ -271,3 +275,12 @@ class PreProcessing(RunEnvironment):
f"{len(set_stations)} valid stations.") f"{len(set_stations)} valid stations.")
return collection, valid_stations return collection, valid_stations
def transformation(self, data_preparation, stations):
    """Estimate and store the transformation parameters from train data.

    If the given data preparation class exposes a ``transformation`` hook, it
    is called with the train-scope requirement arguments; a non-``None``
    result is written to the data store under the key "transformation".
    """
    if not hasattr(data_preparation, "transformation"):
        return
    requirement_args = self.data_store.create_args_dict(data_preparation.requirements(), scope="train")
    spec = data_preparation.transformation(stations, **requirement_args)
    if spec is not None:
        self.data_store.set("transformation", spec)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment