diff --git a/mlair/data_handler/default_data_handler.py b/mlair/data_handler/default_data_handler.py index 6fa7952d4bc0a278f17f767073969822924b6d5f..68aff5947743dfdc66f95d93d5b8b284a87789d8 100644 --- a/mlair/data_handler/default_data_handler.py +++ b/mlair/data_handler/default_data_handler.py @@ -309,6 +309,7 @@ class DefaultDataHandler(AbstractDataHandler): os.remove(_res_file) transformation_dict = cls.update_transformation_dict(dh, transformation_dict) pool.close() + pool.join() else: # serial solution logging.info("use serial transformation approach") sp_keys.update({"return_strategy": "result"}) diff --git a/mlair/helpers/time_tracking.py b/mlair/helpers/time_tracking.py index 3105ebcd04406b7d449ba312bd3af46f83e3a716..cf366db88adc524e90c2b771bef77c71ee5a9502 100644 --- a/mlair/helpers/time_tracking.py +++ b/mlair/helpers/time_tracking.py @@ -68,12 +68,13 @@ class TimeTracking(object): The only disadvantage of the latter implementation is, that the duration is logged but not returned. """ - def __init__(self, start=True, name="undefined job", logging_level=logging.INFO): + def __init__(self, start=True, name="undefined job", logging_level=logging.INFO, log_on_enter=False): """Construct time tracking and start if enabled.""" self.start = None self.end = None self._name = name self._logging = {logging.INFO: logging.info, logging.DEBUG: logging.debug}.get(logging_level, logging.info) + self._log_on_enter = log_on_enter if start: self._start() @@ -124,6 +125,7 @@ class TimeTracking(object): def __enter__(self): """Context manager.""" + self._logging(f"start {self._name}") if self._log_on_enter is True else None return self def __exit__(self, exc_type, exc_val, exc_tb) -> None: diff --git a/mlair/plotting/data_insight_plotting.py b/mlair/plotting/data_insight_plotting.py index 8d4ab2689b1eea24dc9d39d53b04e51405a3a874..47051a500c29349197f3163861a0fe40cade525d 100644 --- a/mlair/plotting/data_insight_plotting.py +++ b/mlair/plotting/data_insight_plotting.py @@ -711,6 +711,7 @@ class PlotPeriodogram(AbstractPlotClass): # pragma: no cover for i, p in enumerate(output): res.append(p.get()) pool.close() + pool.join() else: # serial solution for var in d[self.variables_dim].values: res.append(f_proc(var, d.loc[{self.variables_dim: var}].squeeze().dropna(self.time_dim))) @@ -735,6 +736,7 @@ class PlotPeriodogram(AbstractPlotClass): # pragma: no cover for i, p in enumerate(output): res.append(p.get()) pool.close() + pool.join() else: for g in generator: res.append(f_proc_2(g, m, pos, self.variables_dim, self.time_dim, self.f_index, use_last_input_value)) diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index dbffc5ca206e022afbc1729d3589f287ccebdc11..3f20d7b5cd8fa8d57c43f204b537ef02c08a8c95 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -18,7 +18,7 @@ import xarray as xr from mlair.configuration import path_config from mlair.data_handler import Bootstraps, KerasIterator from mlair.helpers.datastore import NameNotFoundInDataStore -from mlair.helpers import TimeTracking, statistics, extract_value, remove_items, to_list, tables +from mlair.helpers import TimeTracking, TimeTrackingWrapper, statistics, extract_value, remove_items, to_list, tables from mlair.model_modules.linear_model import OrdinaryLeastSquaredModel from mlair.model_modules import AbstractModelClass from mlair.plotting.postprocessing_plotting import PlotMonthlySummary, PlotClimatologicalSkillScore, \ @@ -114,17 +114,17 @@ class PostProcessing(RunEnvironment): # feature importance bootstraps if self.data_store.get("evaluate_feature_importance", "postprocessing"): - with TimeTracking(name="calculate feature importance using bootstraps"): + with TimeTracking(name="evaluate_feature_importance", log_on_enter=True): create_new_bootstraps = self.data_store.get("create_new_bootstraps", "feature_importance") bootstrap_method = self.data_store.get("bootstrap_method", "feature_importance") bootstrap_type = self.data_store.get("bootstrap_type", "feature_importance") self.calculate_feature_importance(create_new_bootstraps, bootstrap_type=bootstrap_type, bootstrap_method=bootstrap_method) - if self.feature_importance_skill_scores is not None: - self.report_feature_importance_results(self.feature_importance_skill_scores) + if self.feature_importance_skill_scores is not None: + self.report_feature_importance_results(self.feature_importance_skill_scores) # skill scores and error metrics - with TimeTracking(name="calculate skill scores"): + with TimeTracking(name="calculate_error_metrics", log_on_enter=True): skill_score_competitive, _, skill_score_climatological, errors = self.calculate_error_metrics() self.skill_scores = (skill_score_competitive, skill_score_climatological) self.report_error_metrics(errors) @@ -134,12 +134,14 @@ class PostProcessing(RunEnvironment): # plotting self.plot() + @TimeTrackingWrapper def estimate_sample_uncertainty(self, separate_ahead=False): """ Estimate sample uncertainty by using a bootstrap approach. Forecasts are split into individual blocks along time and randomly drawn with replacement. The resulting behaviour of the error indicates the robustness of each analyzed model to quantify which model might be superior compared to others. """ + logging.info("start estimate_sample_uncertainty") n_boots = self.data_store.get_default("n_boots", default=1000, scope="uncertainty_estimate") block_length = self.data_store.get_default("block_length", default="1m", scope="uncertainty_estimate") evaluate_competitors = self.data_store.get_default("evaluate_competitors", default=True, diff --git a/mlair/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py index 873919fa93af3e4a43c3b16c382d9746ec26a573..116a37b305fbe0c2e81dd89bd8ba43257d29a61c 100644 --- a/mlair/run_modules/pre_processing.py +++ b/mlair/run_modules/pre_processing.py @@ -266,6 +266,7 @@ class PreProcessing(RunEnvironment): collection.add(dh) valid_stations.append(s) pool.close() + pool.join() else: # serial solution logging.info("use serial validate station approach") kwargs.update({"return_strategy": "result"})