diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index 5bf6e90fba104b6c942f6c7a07071d89f90b5e52..ed541a66b992201cd46aac22e0df3267e786d693 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -10,7 +10,6 @@ import sys import traceback import pathos import multiprocess.context as ctx -ctx._force_start_method('spawn') import psutil from typing import Dict, Tuple, Union, List, Callable @@ -699,6 +698,15 @@ class PostProcessing(RunEnvironment): logging.info(f"start train_ols_model on train data") self.ols_model = OrdinaryLeastSquaredModel(self.train_data) + @staticmethod + def _get_ctx_context(): + _default_context = ctx._default_context + if _default_context._actual_context is None: + _actual_context = _default_context._default_context._name + else: + _actual_context = ctx._default_context._actual_context._name + return _actual_context + @TimeTrackingWrapper def make_prediction(self, subset): """ @@ -719,6 +727,8 @@ class PostProcessing(RunEnvironment): n_process = min([psutil.cpu_count(logical=False), len(subset), max_process]) # use only physical cpus if n_process > 1 and use_multiprocessing is True: # parallel solution logging.info("use parallel make prediction approach") + _actual_context = self._get_ctx_context() + ctx._force_start_method('spawn') pool = pathos.multiprocessing.ProcessingPool(n_process) logging.info(f"running {getattr(pool, 'ncpus')} processes in parallel") output = [ @@ -732,6 +742,7 @@ class PostProcessing(RunEnvironment): pool.close() pool.join() pool.clear() + ctx._force_start_method(_actual_context) else: # serial solution logging.info("use serial make prediction approach") for i, data in enumerate(subset):