Skip to content
Snippets Groups Projects
Commit 7d3b5411 authored by leufen1's avatar leufen1
Browse files

introduced parallel transformation for all data handlers inheriting from default data handler

parent 83fd35e2
No related branches found
No related tags found
3 merge requests!226Develop,!225Resolve "release v1.2.0",!202Resolve "parallel transformation"
Pipeline #54109 passed
This commit is part of merge request !225. Comments created here will be created in the context of that merge request.
......@@ -11,6 +11,7 @@ import pickle
import shutil
from functools import reduce
from typing import Tuple, Union, List
import multiprocessing
import numpy as np
import xarray as xr
......@@ -251,14 +252,28 @@ class DefaultDataHandler(AbstractDataHandler):
return
means = [None, None]
stds = [None, None]
if multiprocessing.cpu_count() > 1: # parallel solution
logging.info("use parallel transformation approach")
pool = multiprocessing.Pool()
output = [
pool.apply_async(f_proc, args=(cls.data_handler_transformation, station), kwds=sp_keys)
for station in set_stations]
for p in output:
dh, s = p.get()
if dh is not None:
for i, data in enumerate([dh.input_data, dh.target_data]):
means[i] = data.mean.copy(deep=True) if means[i] is None else means[i].combine_first(data.mean)
stds[i] = data.std.copy(deep=True) if stds[i] is None else stds[i].combine_first(data.std)
else: # serial solution
logging.info("use serial transformation approach")
for station in set_stations:
try:
sp = cls.data_handler_transformation(station, **sp_keys)
for i, data in enumerate([sp.input_data, sp.target_data]):
dh, s = f_proc(cls.data_handler_transformation, station, **sp_keys)
if dh is not None:
for i, data in enumerate([dh.input_data, dh.target_data]):
means[i] = data.mean.copy(deep=True) if means[i] is None else means[i].combine_first(data.mean)
stds[i] = data.std.copy(deep=True) if stds[i] is None else stds[i].combine_first(data.std)
except (AttributeError, EmptyQueryResult):
continue
if means[0] is None:
return None
transformation_class.inputs.mean = means[0].mean("Stations")
......@@ -269,3 +284,17 @@ class DefaultDataHandler(AbstractDataHandler):
def get_coordinates(self):
    """Delegate the coordinate lookup to the wrapped ``id_class`` data handler."""
    coords = self.id_class.get_coordinates()
    return coords
def f_proc(data_handler, station, **sp_keys):
    """
    Try to create a data handler for given arguments. If build fails, this station does not fulfil all requirements and
    therefore f_proc will return None as indication. On a successful build, f_proc returns the built data handler and
    the station that was used. This function must be implemented globally to work together with multiprocessing.
    """
    handler = None
    try:
        handler = data_handler(station, **sp_keys)
    except (AttributeError, EmptyQueryResult, KeyError, ValueError) as err:
        # station is dropped (None result) but the run continues for the remaining stations
        logging.info(f"remove station {station} because it raised an error: {err}")
    return handler, station
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment