Skip to content
Snippets Groups Projects
Commit eec9315d authored by lukas leufen's avatar lukas leufen
Browse files

Merge branch 'lukas_issue225_feat_parallel-transformation' into 'develop'

Resolve "parallel transformation"

See merge request toar/mlair!202
parents 5b91f2f9 7d3b5411
No related branches found
No related tags found
3 merge requests!226Develop,!225Resolve "release v1.2.0",!202Resolve "parallel transformation"
Pipeline #54112 passed
This commit is part of merge request !225. Comments created here will be created in the context of that merge request.
......@@ -11,6 +11,7 @@ import pickle
import shutil
from functools import reduce
from typing import Tuple, Union, List
import multiprocessing
import numpy as np
import xarray as xr
......@@ -251,14 +252,28 @@ class DefaultDataHandler(AbstractDataHandler):
return
means = [None, None]
stds = [None, None]
if multiprocessing.cpu_count() > 1: # parallel solution
logging.info("use parallel transformation approach")
pool = multiprocessing.Pool()
output = [
pool.apply_async(f_proc, args=(cls.data_handler_transformation, station), kwds=sp_keys)
for station in set_stations]
for p in output:
dh, s = p.get()
if dh is not None:
for i, data in enumerate([dh.input_data, dh.target_data]):
means[i] = data.mean.copy(deep=True) if means[i] is None else means[i].combine_first(data.mean)
stds[i] = data.std.copy(deep=True) if stds[i] is None else stds[i].combine_first(data.std)
else: # serial solution
logging.info("use serial transformation approach")
for station in set_stations:
try:
sp = cls.data_handler_transformation(station, **sp_keys)
for i, data in enumerate([sp.input_data, sp.target_data]):
dh, s = f_proc(cls.data_handler_transformation, station, **sp_keys)
if dh is not None:
for i, data in enumerate([dh.input_data, dh.target_data]):
means[i] = data.mean.copy(deep=True) if means[i] is None else means[i].combine_first(data.mean)
stds[i] = data.std.copy(deep=True) if stds[i] is None else stds[i].combine_first(data.std)
except (AttributeError, EmptyQueryResult):
continue
if means[0] is None:
return None
transformation_class.inputs.mean = means[0].mean("Stations")
......@@ -269,3 +284,17 @@ class DefaultDataHandler(AbstractDataHandler):
def get_coordinates(self):
return self.id_class.get_coordinates()
def f_proc(data_handler, station, **sp_keys):
"""
Try to create a data handler for given arguments. If build fails, this station does not fulfil all requirements and
therefore f_proc will return None as indication. On a successful build, f_proc returns the built data handler and
the station that was used. This function must be implemented globally to work together with multiprocessing.
"""
try:
res = data_handler(station, **sp_keys)
except (AttributeError, EmptyQueryResult, KeyError, ValueError) as e:
logging.info(f"remove station {station} because it raised an error: {e}")
res = None
return res, station
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment