Commit eec9315d authored by lukas leufen

Merge branch 'lukas_issue225_feat_parallel-transformation' into 'develop'

Resolve "parallel transformation"

See merge request toar/mlair!202
parents 5b91f2f9 7d3b5411
3 merge requests: !226 Develop, !225 Resolve "release v1.2.0", !202 Resolve "parallel transformation"
Pipeline #54112 passed
@@ -11,6 +11,7 @@ import pickle
 import shutil
 from functools import reduce
 from typing import Tuple, Union, List
+import multiprocessing
 
 import numpy as np
 import xarray as xr
@@ -251,14 +252,28 @@ class DefaultDataHandler(AbstractDataHandler):
             return
         means = [None, None]
         stds = [None, None]
-        for station in set_stations:
-            try:
-                sp = cls.data_handler_transformation(station, **sp_keys)
-                for i, data in enumerate([sp.input_data, sp.target_data]):
-                    means[i] = data.mean.copy(deep=True) if means[i] is None else means[i].combine_first(data.mean)
-                    stds[i] = data.std.copy(deep=True) if stds[i] is None else stds[i].combine_first(data.std)
-            except (AttributeError, EmptyQueryResult):
-                continue
+
+        if multiprocessing.cpu_count() > 1:  # parallel solution
+            logging.info("use parallel transformation approach")
+            pool = multiprocessing.Pool()
+            output = [
+                pool.apply_async(f_proc, args=(cls.data_handler_transformation, station), kwds=sp_keys)
+                for station in set_stations]
+            for p in output:
+                dh, s = p.get()
+                if dh is not None:
+                    for i, data in enumerate([dh.input_data, dh.target_data]):
+                        means[i] = data.mean.copy(deep=True) if means[i] is None else means[i].combine_first(data.mean)
+                        stds[i] = data.std.copy(deep=True) if stds[i] is None else stds[i].combine_first(data.std)
+        else:  # serial solution
+            logging.info("use serial transformation approach")
+            for station in set_stations:
+                dh, s = f_proc(cls.data_handler_transformation, station, **sp_keys)
+                if dh is not None:
+                    for i, data in enumerate([dh.input_data, dh.target_data]):
+                        means[i] = data.mean.copy(deep=True) if means[i] is None else means[i].combine_first(data.mean)
+                        stds[i] = data.std.copy(deep=True) if stds[i] is None else stds[i].combine_first(data.std)
+
         if means[0] is None:
             return None
         transformation_class.inputs.mean = means[0].mean("Stations")
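For readers unfamiliar with the pattern introduced in the hunk above, here is a minimal, self-contained sketch of the apply_async fan-out it uses; the square worker and the input range are illustrative placeholders, not part of MLAir:

import multiprocessing


def square(x):
    # The worker must live at module level so the pool can pickle it.
    return x * x


if __name__ == "__main__":
    if multiprocessing.cpu_count() > 1:  # parallel solution
        with multiprocessing.Pool() as pool:
            # Submit every job first; each apply_async call returns an
            # AsyncResult whose .get() blocks until that job has finished.
            output = [pool.apply_async(square, args=(x,)) for x in range(10)]
            results = [p.get() for p in output]
    else:  # serial fallback, mirrors the else branch in the hunk
        results = [square(x) for x in range(10)]
    print(results)

The sketch wraps the pool in a with block so worker processes are shut down deterministically. Note also that AsyncResult.get() re-raises any exception raised inside the worker, which is why f_proc (added in the next hunk) catches errors itself and returns None instead of raising.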
@@ -268,4 +283,18 @@ class DefaultDataHandler(AbstractDataHandler):
         return transformation_class
 
     def get_coordinates(self):
-        return self.id_class.get_coordinates()
\ No newline at end of file
+        return self.id_class.get_coordinates()
+
+
+def f_proc(data_handler, station, **sp_keys):
+    """
+    Try to create a data handler for the given arguments. If the build fails, the station does not fulfil all
+    requirements and f_proc therefore returns None as an indication. On success, f_proc returns the built data
+    handler and the station that was used. This function must be defined at module level to work with multiprocessing.
+    """
+    try:
+        res = data_handler(station, **sp_keys)
+    except (AttributeError, EmptyQueryResult, KeyError, ValueError) as e:
+        logging.info(f"remove station {station} because it raised an error: {e}")
+        res = None
+    return res, station
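The docstring's requirement that f_proc be implemented globally follows from how multiprocessing ships work to the pool: the callable and its arguments are pickled, and pickle can serialize a plain function only by reference to an importable, module-level name. A small sketch of the failure mode, with illustrative names:

import multiprocessing


def top_level(x):  # picklable: importable by its module-level name
    return x + 1


def main():
    def nested(x):  # not picklable: a local object inside main()
        return x + 1

    with multiprocessing.Pool(2) as pool:
        print(pool.apply_async(top_level, args=(1,)).get())  # prints 2
        try:
            pool.apply_async(nested, args=(1,)).get()
        except Exception as err:  # typically "Can't pickle local object ..."
            print(f"nested worker rejected: {err}")


if __name__ == "__main__":
    main()

The same constraint applies to the arguments: cls.data_handler_transformation and the contents of sp_keys must themselves be picklable for the parallel branch to work.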