diff --git a/mlair/data_handler/default_data_handler.py b/mlair/data_handler/default_data_handler.py index 52835975101f5ce6881b72b127e16c0e299dfb14..ddf276cf2d88c108d8622c507471f989c4f99e8b 100644 --- a/mlair/data_handler/default_data_handler.py +++ b/mlair/data_handler/default_data_handler.py @@ -13,6 +13,7 @@ from functools import reduce from typing import Tuple, Union, List import multiprocessing import psutil +import dask import numpy as np import xarray as xr @@ -83,11 +84,20 @@ class DefaultDataHandler(AbstractDataHandler): if store_processed_data is True: self._cleanup() if fresh_store is True else None data = {"X": self._X, "Y": self._Y, "X_extreme": self._X_extreme, "Y_extreme": self._Y_extreme} + data = self._force_dask_computation(data) with open(self._save_file, "wb") as f: pickle.dump(data, f) logging.debug(f"save pickle data to {self._save_file}") self._reset_data() + @staticmethod + def _force_dask_computation(data): + try: + data = dask.compute(data)[0] + except: + pass + return data + def _load(self): try: with open(self._save_file, "rb") as f: