Commit 1cdcd1e0 authored by leufen1

transformation can now handle each given data handler

parent ebc64a21
@@ -53,6 +53,7 @@ class DefaultDataHandler(AbstractDataHandler):
         self._Y = None
         self._X_extreme = None
         self._Y_extreme = None
+        self._data_intersection = None
         self._use_multiprocessing = use_multiprocessing
         self._max_number_multiprocessing = max_number_multiprocessing
         _name_affix = str(f"{str(self.id_class)}_{name_affix}" if name_affix is not None else id(self))
@@ -172,9 +173,14 @@ class DefaultDataHandler(AbstractDataHandler):
         else:
             X = list(map(lambda x: x.sel({dim: intersect}), X_original))
             Y = Y_original.sel({dim: intersect})
+            self._data_intersection = intersect
         self._X, self._Y = X, Y

     def get_observation(self):
-        return self.id_class.observation.copy().squeeze()
+        dim = self.time_dim
+        if self._data_intersection is not None:
+            return self.id_class.observation.sel({dim: self._data_intersection}).copy().squeeze()
+        else:
+            return self.id_class.observation.copy().squeeze()

     def apply_transformation(self, data, base="target", dim=0, inverse=False):
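What the new get_observation branch does, sketched with toy xarray objects (the names obs and intersect are illustrative, not project data): once _data_intersection has been stored, the observation series is restricted to exactly the time steps that survived the X/Y intersection.

# Illustrative sketch, assuming an observation DataArray over a "datetime" dimension.
import numpy as np
import pandas as pd
import xarray as xr

time = pd.date_range("2020-01-01", periods=5, freq="D")
obs = xr.DataArray(np.arange(5.0), dims="datetime", coords={"datetime": time})

intersect = time[1:4]  # pretend only three time steps survived the X/Y intersection

restricted = obs.sel({"datetime": intersect}).copy().squeeze()
print(restricted.sizes["datetime"])  # 3 instead of 5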
@@ -248,7 +254,7 @@ class DefaultDataHandler(AbstractDataHandler):
             d.coords[dim] = d.coords[dim].values + np.timedelta64(*timedelta)

     @classmethod
-    def transformation(cls, set_stations, tmp_path=None, **kwargs):
+    def transformation(cls, set_stations, tmp_path=None, dh_transformation=None, **kwargs):
         """
         ### supported transformation methods
@@ -278,31 +284,14 @@ class DefaultDataHandler(AbstractDataHandler):
         If min and max are not None, the default data handler expects these parameters to match the data and applies
         these values to the data. Make sure that all dimensions and/or coordinates are in agreement.
         """
+        if dh_transformation is None:
+            dh_transformation = cls.data_handler_transformation
-        sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls._requirements if k in kwargs}
+        sp_keys = {k: copy.deepcopy(kwargs[k]) for k in dh_transformation.requirements() if k in kwargs}
         if "transformation" not in sp_keys.keys():
             return
         transformation_dict = ({}, {})
-
-        def _inner():
-            """Inner method that is performed in both serial and parallel approach."""
-            if dh is not None:
-                for i, transformation in enumerate(dh._transformation):
-                    for var in transformation.keys():
-                        if var not in transformation_dict[i].keys():
-                            transformation_dict[i][var] = {}
-                        opts = transformation[var]
-                        if not transformation_dict[i][var].get("method", opts["method"]) == opts["method"]:
-                            # data handlers with filters are allowed to change transformation method to standardise
-                            assert hasattr(dh, "filter_dim") and opts["method"] == "standardise"
-                        transformation_dict[i][var]["method"] = opts["method"]
-                        for k in ["mean", "std", "min", "max"]:
-                            old = transformation_dict[i][var].get(k, None)
-                            new = opts.get(k)
-                            transformation_dict[i][var][k] = new if old is None else old.combine_first(new)
-                        if "feature_range" in opts.keys():
-                            transformation_dict[i][var]["feature_range"] = opts.get("feature_range", None)
-
         max_process = kwargs.get("max_number_multiprocessing", 16)
         n_process = min([psutil.cpu_count(logical=False), len(set_stations), max_process])  # use only physical cpus
         if n_process > 1 and kwargs.get("use_multiprocessing", True) is True:  # parallel solution
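The core of this hunk is a fallback-and-filter pattern: if no dedicated transformation handler is passed, the class-level data_handler_transformation is used, and only the keyword arguments that this handler declares via requirements() are forwarded. A minimal self-contained sketch, with a hypothetical handler class standing in for a real data handler:

import copy

class DummyHandler:  # hypothetical stand-in for a data handler class
    @classmethod
    def requirements(cls):
        return ["transformation", "window_history_size"]

def build_sp_keys(dh_transformation=None, **kwargs):
    if dh_transformation is None:
        dh_transformation = DummyHandler  # stand-in for cls.data_handler_transformation
    # keep only kwargs the chosen handler actually requires
    return {k: copy.deepcopy(kwargs[k]) for k in dh_transformation.requirements() if k in kwargs}

print(build_sp_keys(transformation={"o3": {"method": "standardise"}}, unrelated_arg=1))
# {'transformation': {'o3': {'method': 'standardise'}}}  -- unrelated_arg is dropped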
@@ -311,24 +300,29 @@ class DefaultDataHandler(AbstractDataHandler):
             logging.info(f"running {getattr(pool, '_processes')} processes in parallel")
             sp_keys.update({"tmp_path": tmp_path, "return_strategy": "reference"})
             output = [
-                pool.apply_async(f_proc, args=(cls.data_handler_transformation, station), kwds=sp_keys)
+                pool.apply_async(f_proc, args=(dh_transformation, station), kwds=sp_keys)
                 for station in set_stations]
             for p in output:
                 _res_file, s = p.get()
                 with open(_res_file, "rb") as f:
                     dh = dill.load(f)
                 os.remove(_res_file)
-                _inner()
+                transformation_dict = cls.update_transformation_dict(dh, transformation_dict)
             pool.close()
         else:  # serial solution
             logging.info("use serial transformation approach")
             sp_keys.update({"return_strategy": "result"})
             for station in set_stations:
-                dh, s = f_proc(cls.data_handler_transformation, station, **sp_keys)
+                dh, s = f_proc(dh_transformation, station, **sp_keys)
-                _inner()
+                transformation_dict = cls.update_transformation_dict(dh, transformation_dict)

         # aggregate all information
         iter_dim = sp_keys.get("iter_dim", cls.DEFAULT_ITER_DIM)
+        transformation_dict = cls.aggregate_transformation(transformation_dict, iter_dim)
+        return transformation_dict
+
+    @classmethod
+    def aggregate_transformation(cls, transformation_dict, iter_dim):
         pop_list = []
         for i, transformation in enumerate(transformation_dict):
             for k in transformation.keys():
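The dispatch logic around these two hunks is unchanged in shape: cap the pool size at the number of physical CPUs, the number of stations, and a user-defined maximum, then fall back to a serial loop. A generic sketch of that pattern, where f_proc and DummyTransformation are placeholders, not the project's actual worker and handler:

import multiprocessing
import psutil

class DummyTransformation:  # placeholder for a data handler class
    pass

def f_proc(handler_cls, station):
    # placeholder worker; the real f_proc builds the transformation for one station
    return handler_cls.__name__, station

if __name__ == "__main__":
    set_stations = ["DEBW107", "DEBW013"]
    max_process = 16
    # psutil.cpu_count(logical=False) may return None, hence the fallback to 1
    n_physical = psutil.cpu_count(logical=False) or 1
    n_process = min([n_physical, len(set_stations), max_process])
    if n_process > 1:  # parallel solution
        with multiprocessing.Pool(n_process) as pool:
            output = [pool.apply_async(f_proc, args=(DummyTransformation, station))
                      for station in set_stations]
            results = [p.get() for p in output]
    else:  # serial solution
        results = [f_proc(DummyTransformation, station) for station in set_stations]
    print(results)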
@@ -349,6 +343,27 @@ class DefaultDataHandler(AbstractDataHandler):
                 transformation_dict[i].pop(k)
         return transformation_dict

+    @classmethod
+    def update_transformation_dict(cls, dh, transformation_dict):
+        """Merge a single data handler's transformation options into transformation_dict; used by both the serial and the parallel approach."""
+        if dh is not None:
+            for i, transformation in enumerate(dh._transformation):
+                for var in transformation.keys():
+                    if var not in transformation_dict[i].keys():
+                        transformation_dict[i][var] = {}
+                    opts = transformation[var]
+                    if not transformation_dict[i][var].get("method", opts["method"]) == opts["method"]:
+                        # data handlers with filters are allowed to change transformation method to standardise
+                        assert hasattr(dh, "filter_dim") and opts["method"] == "standardise"
+                    transformation_dict[i][var]["method"] = opts["method"]
+                    for k in ["mean", "std", "min", "max"]:
+                        old = transformation_dict[i][var].get(k, None)
+                        new = opts.get(k)
+                        transformation_dict[i][var][k] = new if old is None else old.combine_first(new)
+                    if "feature_range" in opts.keys():
+                        transformation_dict[i][var]["feature_range"] = opts.get("feature_range", None)
+        return transformation_dict
+
     def get_coordinates(self):
         return self.id_class.get_coordinates()
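The merging rule in update_transformation_dict relies on xarray's combine_first: statistics that were already collected win, and a newly processed station only fills coordinates that are still missing. A toy example of that behaviour (data invented for illustration):

import numpy as np
import xarray as xr

old = xr.DataArray([1.0, np.nan], dims="Stations", coords={"Stations": ["A", "B"]})
new = xr.DataArray([5.0, 2.0], dims="Stations", coords={"Stations": ["B", "C"]})

merged = old.combine_first(new)
print(merged.sel(Stations="B").item())         # 5.0: old had NaN there, so new fills it
print(list(merged.coords["Stations"].values))  # ['A', 'B', 'C']: coordinate union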