diff --git a/mlair/helpers/__init__.py b/mlair/helpers/__init__.py index 9e2f612c86dc0477693567210493fbdcf3002954..4671334c16267be819ab8ee0ad96b7135ee01531 100644 --- a/mlair/helpers/__init__.py +++ b/mlair/helpers/__init__.py @@ -3,4 +3,4 @@ from .testing import PyTestRegex, PyTestAllEqual from .time_tracking import TimeTracking, TimeTrackingWrapper from .logger import Logger -from .helpers import remove_items, float_round, dict_to_xarray, to_list, extract_value +from .helpers import remove_items, float_round, dict_to_xarray, to_list, extract_value, select_from_dict diff --git a/mlair/helpers/helpers.py b/mlair/helpers/helpers.py index 3ecf1f6213bf39d2e3571a1b451173b981a3dadf..36470ebc1c3a008c0f6ecca11478d83d6fa57cec 100644 --- a/mlair/helpers/helpers.py +++ b/mlair/helpers/helpers.py @@ -99,6 +99,19 @@ def remove_items(obj: Union[List, Dict], items: Any): raise TypeError(f"{inspect.stack()[0][3]} does not support type {type(obj)}.") +def select_from_dict(dict_obj: dict, sel_list: str): + """ + Extract all key values pairs whose key is contained in the sel_list. + + Does not perform a check if all elements of sel_list are keys of dict_obj. Therefore the number of pairs in the + returned dict is always smaller or equal to the number of elements in the sel_list. + """ + sel_list = to_list(sel_list) + assert isinstance(dict_obj, dict) + sel_dict = {k: v for k, v in dict_obj.items() if k in sel_list} + return sel_dict + + def extract_value(encapsulated_value): try: return extract_value(encapsulated_value[0])