diff --git a/mlair/configuration/path_config.py b/mlair/configuration/path_config.py
index e7418b984dab74b0527b8dca05a9f6c3636ac18f..6b9c799ceb190b9150be3a4cfcd336eaf45aa768 100644
--- a/mlair/configuration/path_config.py
+++ b/mlair/configuration/path_config.py
@@ -3,6 +3,7 @@ import getpass
 import logging
 import os
 import re
+import shutil
 import socket
 
 from typing import Union
@@ -112,17 +113,23 @@ def set_bootstrap_path(bootstrap_path: str, data_path: str) -> str:
     return os.path.abspath(bootstrap_path)
 
 
-def check_path_and_create(path: str) -> None:
+def check_path_and_create(path: str, remove_existing: bool = False) -> None:
     """
     Check a given path and create if not existing.
 
     :param path: path to check and create
+    :param remove_existing: if set to true an existing folder is removed and replaced by a new one (default False).
     """
     try:
         os.makedirs(path)
         logging.debug(f"Created path: {path}")
     except FileExistsError:
-        logging.debug(f"Path already exists: {path}")
+        if remove_existing is True:
+            logging.debug(f"Remove / clean path: {path}")
+            shutil.rmtree(path)
+            check_path_and_create(path, remove_existing=False)
+        else:
+            logging.debug(f"Path already exists: {path}")
 
 
 def get_host():
diff --git a/mlair/data_handler/data_handler_single_station.py b/mlair/data_handler/data_handler_single_station.py
index 3392e41631fb395e95fc15b9199d81e1fd02121d..f749f53641755f204a739905a36418466e3f37d4 100644
--- a/mlair/data_handler/data_handler_single_station.py
+++ b/mlair/data_handler/data_handler_single_station.py
@@ -5,6 +5,8 @@ __date__ = '2020-07-20'
 
 import copy
 import datetime as dt
+import gc
+
 import dill
 import hashlib
 import logging
@@ -107,6 +109,13 @@ class DataHandlerSingleStation(AbstractDataHandler):
 
         # create samples
         self.setup_samples()
+        self.clean_up()
+
+    def clean_up(self):
+        self._data = None
+        self.input_data = None
+        self.target_data = None
+        gc.collect()
 
     def __str__(self):
         return self.station[0]
@@ -253,7 +262,7 @@ class DataHandlerSingleStation(AbstractDataHandler):
         hash = self._get_hash()
         filename = os.path.join(self.lazy_path, hash + ".pickle")
         if not os.path.exists(filename):
-            dill.dump(self._create_lazy_data(), file=open(filename, "wb"))
+            dill.dump(self._create_lazy_data(), file=open(filename, "wb"), protocol=4)
 
     def _create_lazy_data(self):
         return [self._data, self.meta, self.input_data, self.target_data]
diff --git a/mlair/data_handler/data_handler_with_filter.py b/mlair/data_handler/data_handler_with_filter.py
index e76f396aea80b2db76e01ea5baacf71d024b0d23..785eb7dffff28a676342feace519a6db0871c1df 100644
--- a/mlair/data_handler/data_handler_with_filter.py
+++ b/mlair/data_handler/data_handler_with_filter.py
@@ -393,7 +393,8 @@ class DataHandlerClimateFirFilterSingleStation(DataHandlerFirFilterSingleStation
                                climate_filter.filtered_data]
 
         # create input data with filter index
-        input_data = xr.concat(climate_filter_data, pd.Index(self.create_filter_index(), name=self.filter_dim))
+        input_data = xr.concat(climate_filter_data, pd.Index(self.create_filter_index(add_unfiltered_index=False),
+                                                             name=self.filter_dim))
 
         # add unfiltered raw data
         if self._add_unfiltered is True:
@@ -410,7 +411,7 @@ class DataHandlerClimateFirFilterSingleStation(DataHandlerFirFilterSingleStation
         # self.input_data.sel(filter="low", variables="temp", Stations="DEBW107").plot()
         # self.input_data.sel(variables="temp", Stations="DEBW107").plot.line(hue="filter")
 
-    def create_filter_index(self) -> pd.Index:
+    def create_filter_index(self, add_unfiltered_index=True) -> pd.Index:
         """
         Round cut off periods in days and append 'res' for residuum index.
 
@@ -421,7 +422,7 @@ class DataHandlerClimateFirFilterSingleStation(DataHandlerFirFilterSingleStation
         f = lambda x: int(np.round(x)) if x >= 10 else np.round(x, 1)
         index = list(map(f, index.tolist()))
         index = list(map(lambda x: str(x) + "d", index)) + ["res"]
-        if self._add_unfiltered:
+        if self._add_unfiltered and add_unfiltered_index:
             index.append("unfiltered")
         self.filter_dim_order = index
         return pd.Index(index, name=self.filter_dim)
diff --git a/mlair/data_handler/default_data_handler.py b/mlair/data_handler/default_data_handler.py
index 35959582a9aa18ee412ed99a69bc75612641e42a..22a8ca4150c46ad1ce9ba0d005fa643e27906f53 100644
--- a/mlair/data_handler/default_data_handler.py
+++ b/mlair/data_handler/default_data_handler.py
@@ -8,6 +8,7 @@ import gc
 import logging
 import os
 import pickle
+import random
 import dill
 import shutil
 from functools import reduce
@@ -92,7 +93,7 @@ class DefaultDataHandler(AbstractDataHandler):
         data = {"X": self._X, "Y": self._Y, "X_extreme": self._X_extreme, "Y_extreme": self._Y_extreme}
         data = self._force_dask_computation(data)
         with open(self._save_file, "wb") as f:
-            dill.dump(data, f)
+            dill.dump(data, f, protocol=4)
         logging.debug(f"save pickle data to {self._save_file}")
         self._reset_data()
 
@@ -250,7 +251,7 @@ class DefaultDataHandler(AbstractDataHandler):
             d.coords[dim] = d.coords[dim].values + np.timedelta64(*timedelta)
 
     @classmethod
-    def transformation(cls, set_stations, **kwargs):
+    def transformation(cls, set_stations, tmp_path=None, **kwargs):
         """
         ### supported transformation methods
 
@@ -309,18 +310,22 @@ class DefaultDataHandler(AbstractDataHandler):
         n_process = min([psutil.cpu_count(logical=False), len(set_stations), max_process])  # use only physical cpus
         if n_process > 1 and kwargs.get("use_multiprocessing", True) is True:  # parallel solution
             logging.info("use parallel transformation approach")
-            pool = multiprocessing.Pool(
-                min([psutil.cpu_count(logical=False), len(set_stations), 16]))  # use only physical cpus
+            pool = multiprocessing.Pool(n_process)  # use only physical cpus
             logging.info(f"running {getattr(pool, '_processes')} processes in parallel")
+            sp_keys.update({"tmp_path": tmp_path, "return_strategy": "reference"})
             output = [
                 pool.apply_async(f_proc, args=(cls.data_handler_transformation, station), kwds=sp_keys)
                 for station in set_stations]
             for p in output:
-                dh, s = p.get()
+                _res_file, s = p.get()
+                with open(_res_file, "rb") as f:
+                    dh = dill.load(f)
+                os.remove(_res_file)
                 _inner()
             pool.close()
         else:  # serial solution
             logging.info("use serial transformation approach")
+            sp_keys.update({"return_strategy": "result"})
             for station in set_stations:
                 dh, s = f_proc(cls.data_handler_transformation, station, **sp_keys)
                 _inner()
@@ -351,15 +356,22 @@ class DefaultDataHandler(AbstractDataHandler):
         return self.id_class.get_coordinates()
 
 
-def f_proc(data_handler, station, **sp_keys):
+def f_proc(data_handler, station, return_strategy="", tmp_path=None, **sp_keys):
     """
     Try to create a data handler for given arguments. If build fails, this station does not fulfil all requirements and
    therefore f_proc will return None as indication. On a successful build, f_proc returns the built data handler and
     the station that was used. This function must be implemented globally to work together with multiprocessing.
     """
+    assert return_strategy in ["result", "reference"]
     try:
         res = data_handler(station, **sp_keys)
     except (AttributeError, EmptyQueryResult, KeyError, ValueError) as e:
         logging.info(f"remove station {station} because it raised an error: {e}")
         res = None
-    return res, station
+    if return_strategy == "result":
+        return res, station
+    else:
+        _tmp_file = os.path.join(tmp_path, f"{station}_{'%032x' % random.getrandbits(128)}.pickle")
+        with open(_tmp_file, "wb") as f:
+            dill.dump(res, f, protocol=4)
+        return _tmp_file, station
diff --git a/mlair/helpers/helpers.py b/mlair/helpers/helpers.py
index 5ddaa3ee3fe505eeb7c8082274d9cd888cec720f..4cc7310db32c2ef3bbdb9f70896a2f8455a974fc 100644
--- a/mlair/helpers/helpers.py
+++ b/mlair/helpers/helpers.py
@@ -4,6 +4,7 @@ __date__ = '2019-10-21'
 
 import inspect
 import math
+import sys
 
 import numpy as np
 import xarray as xr
@@ -179,3 +180,34 @@ def convert2xrda(arr: Union[xr.DataArray, xr.Dataset, np.ndarray, int, float],
 
     kwargs.update({'dims': dims, 'coords': coords})
     return xr.DataArray(arr, **kwargs)
+
+
+# def convert_size(size_bytes):
+#     if size_bytes == 0:
+#         return "0B"
+#     size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")
+#     i = int(math.floor(math.log(size_bytes, 1024)))
+#     p = math.pow(1024, i)
+#     s = round(size_bytes / p, 2)
+#     return "%s %s" % (s, size_name[i])
+#
+#
+# def get_size(obj, seen=None):
+#     """Recursively finds size of objects"""
+#     size = sys.getsizeof(obj)
+#     if seen is None:
+#         seen = set()
+#     obj_id = id(obj)
+#     if obj_id in seen:
+#         return 0
+#     # Important mark as seen *before* entering recursion to gracefully handle
+#     # self-referential objects
+#     seen.add(obj_id)
+#     if isinstance(obj, dict):
+#         size += sum([get_size(v, seen) for v in obj.values()])
+#         size += sum([get_size(k, seen) for k in obj.keys()])
+#     elif hasattr(obj, '__dict__'):
+#         size += get_size(obj.__dict__, seen)
+#     elif hasattr(obj, '__iter__') and not isinstance(obj, (str, bytes, bytearray)):
+#         size += sum([get_size(i, seen) for i in obj])
+#     return size
diff --git a/mlair/helpers/statistics.py b/mlair/helpers/statistics.py
index a1e713a8c135800d02ff7c27894485a5da7fae37..fef52fb27d602b5931587ff0fa2d8edd7e0c2d8f 100644
--- a/mlair/helpers/statistics.py
+++ b/mlair/helpers/statistics.py
@@ -314,7 +314,10 @@ class SkillScores:
 
         :return: all CASES as well as all terms
         """
-        ahead_names = list(self.external_data[self.ahead_dim].data)
+        if self.external_data is not None:
+            ahead_names = list(self.external_data[self.ahead_dim].data)
+        else:
+            ahead_names = list(internal_data[self.ahead_dim].data)
 
         all_terms = ['AI', 'AII', 'AIII', 'AIV', 'BI', 'BII', 'BIV', 'CI', 'CIV', 'CASE I', 'CASE II', 'CASE III',
                      'CASE IV']
diff --git a/mlair/run_modules/experiment_setup.py b/mlair/run_modules/experiment_setup.py
index 209859c1ff38efe2667c918aa5b79c96f2524be0..b11a2a33417cc27fa23122e31faa089a2dea321c 100644
--- a/mlair/run_modules/experiment_setup.py
+++ b/mlair/run_modules/experiment_setup.py
@@ -287,6 +287,10 @@ class ExperimentSetup(RunEnvironment):
         self._set_param("logging_path", None, os.path.join(experiment_path, "logging"))
         path_config.check_path_and_create(self.data_store.get("logging_path"))
 
+        # set tmp path
+        self._set_param("tmp_path", None, os.path.join(experiment_path, "tmp"))
+        path_config.check_path_and_create(self.data_store.get("tmp_path"), remove_existing=True)
+
         # setup for data
         self._set_param("stations", stations, default=DEFAULT_STATIONS, apply=helpers.to_list)
         self._set_param("statistics_per_var", statistics_per_var, default=DEFAULT_VAR_ALL_DICT)
diff --git a/mlair/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py
index 08bff85c9c1fe06111ddb47e7a3952404e05c0ac..873919fa93af3e4a43c3b16c382d9746ec26a573 100644
--- a/mlair/run_modules/pre_processing.py
+++ b/mlair/run_modules/pre_processing.py
@@ -10,6 +10,8 @@ from typing import Tuple
 
 import multiprocessing
 import requests
 import psutil
+import random
+import dill
 
 import pandas as pd
@@ -242,6 +244,7 @@ class PreProcessing(RunEnvironment):
         valid_stations = []
         kwargs = self.data_store.create_args_dict(data_handler.requirements(), scope=set_name)
         use_multiprocessing = self.data_store.get("use_multiprocessing")
+        tmp_path = self.data_store.get("tmp_path")
         max_process = self.data_store.get("max_number_multiprocessing")
         n_process = min([psutil.cpu_count(logical=False), len(set_stations), max_process])  # use only physical cpus
 
@@ -249,18 +252,23 @@ class PreProcessing(RunEnvironment):
             logging.info("use parallel validate station approach")
             pool = multiprocessing.Pool(n_process)
             logging.info(f"running {getattr(pool, '_processes')} processes in parallel")
+            kwargs.update({"tmp_path": tmp_path, "return_strategy": "reference"})
             output = [
                 pool.apply_async(f_proc, args=(data_handler, station, set_name, store_processed_data), kwds=kwargs)
                 for station in set_stations]
             for i, p in enumerate(output):
-                dh, s = p.get()
+                _res_file, s = p.get()
                 logging.info(f"...finished: {s} ({int((i + 1.) / len(output) * 100)}%)")
+                with open(_res_file, "rb") as f:
+                    dh = dill.load(f)
+                os.remove(_res_file)
                 if dh is not None:
                     collection.add(dh)
                     valid_stations.append(s)
             pool.close()
         else:  # serial solution
             logging.info("use serial validate station approach")
+            kwargs.update({"return_strategy": "result"})
             for station in set_stations:
                 dh, s = f_proc(data_handler, station, set_name, store_processed_data, **kwargs)
                 if dh is not None:
@@ -268,7 +276,7 @@ class PreProcessing(RunEnvironment):
                     valid_stations.append(s)
 
         logging.info(f"run for {t_outer} to check {len(set_stations)} station(s). Found {len(collection)}/"
-                     f"{len(set_stations)} valid stations.")
+                     f"{len(set_stations)} valid stations ({set_name}).")
         if set_name == "train":
             self.store_data_handler_attributes(data_handler, collection)
         return collection, valid_stations
@@ -288,7 +296,8 @@ class PreProcessing(RunEnvironment):
     def transformation(self, data_handler: AbstractDataHandler, stations):
         if hasattr(data_handler, "transformation"):
             kwargs = self.data_store.create_args_dict(data_handler.requirements(), scope="train")
-            transformation_dict = data_handler.transformation(stations, **kwargs)
+            tmp_path = self.data_store.get_default("tmp_path", default=None)
+            transformation_dict = data_handler.transformation(stations, tmp_path=tmp_path, **kwargs)
             if transformation_dict is not None:
                 self.data_store.set("transformation", transformation_dict)
 
@@ -312,12 +321,13 @@ class PreProcessing(RunEnvironment):
         logging.info("No preparation required because no competitor was provided to the workflow.")
 
 
-def f_proc(data_handler, station, name_affix, store, **kwargs):
+def f_proc(data_handler, station, name_affix, store, return_strategy="", tmp_path=None, **kwargs):
     """
     Try to create a data handler for given arguments. If build fails, this station does not fulfil all requirements and
-    therefore f_proc will return None as indication. On a successfull build, f_proc returns the built data handler and
+    therefore f_proc will return None as indication. On a successful build, f_proc returns the built data handler and
     the station that was used. This function must be implemented globally to work together with multiprocessing.
     """
+    assert return_strategy in ["result", "reference"]
    try:
         res = data_handler.build(station, name_affix=name_affix, store_processed_data=store, **kwargs)
     except (AttributeError, EmptyQueryResult, KeyError, requests.ConnectionError, ValueError, IndexError) as e:
@@ -326,7 +336,15 @@ def f_proc(data_handler, station, name_affix, store, **kwargs):
             f"remove station {station} because it raised an error: {e} -> {' | '.join(f_inspect_error(formatted_lines))}")
         logging.debug(f"detailed information for removal of station {station}: {traceback.format_exc()}")
         res = None
-    return res, station
+    if return_strategy == "result":
+        return res, station
+    else:
+        if tmp_path is None:
+            tmp_path = os.getcwd()
+        _tmp_file = os.path.join(tmp_path, f"{station}_{'%032x' % random.getrandbits(128)}.pickle")
+        with open(_tmp_file, "wb") as f:
+            dill.dump(res, f, protocol=4)
+        return _tmp_file, station
 
 
 def f_inspect_error(formatted):