diff --git a/mlair/helpers/__init__.py b/mlair/helpers/__init__.py index 3a5b8699a11ae39c0d3510a534db1dd144419d09..cf50fa05885d576bd64de67b83df3c8ed6d272e2 100644 --- a/mlair/helpers/__init__.py +++ b/mlair/helpers/__init__.py @@ -1,6 +1,7 @@ """Collection of different supporting functions and classes.""" -from .testing import PyTestRegex, PyTestAllEqual +from .testing import PyTestRegex, PyTestAllEqual, check_nested_equality from .time_tracking import TimeTracking, TimeTrackingWrapper from .logger import Logger -from .helpers import remove_items, float_round, dict_to_xarray, to_list, extract_value, select_from_dict, make_keras_pickable, sort_like +from .helpers import remove_items, float_round, dict_to_xarray, to_list, extract_value, select_from_dict, \ + make_keras_pickable, sort_like, filter_dict_by_value diff --git a/mlair/helpers/helpers.py b/mlair/helpers/helpers.py index b583cf7dc473db96181f88b0ab26e60ee225240d..f69e5b202cf3cdf35d4cf2f7767c8a2804e2da67 100644 --- a/mlair/helpers/helpers.py +++ b/mlair/helpers/helpers.py @@ -176,16 +176,17 @@ def remove_items(obj: Union[List, Dict, Tuple], items: Any): raise TypeError(f"{inspect.stack()[0][3]} does not support type {type(obj)}.") -def select_from_dict(dict_obj: dict, sel_list: Any, remove_none=False): +def select_from_dict(dict_obj: dict, sel_list: Any, remove_none: bool = False, filter_cond: bool = True) -> dict: """ Extract all key values pairs whose key is contained in the sel_list. - Does not perform a check if all elements of sel_list are keys of dict_obj. Therefore the number of pairs in the - returned dict is always smaller or equal to the number of elements in the sel_list. + Does not perform a check if all elements of sel_list are keys of dict_obj. Therefore, the number of pairs in the + returned dict is always smaller or equal to the number of elements in the sel_list. If `filter_cond` is given, this + method either return the parts of the input dictionary that are included or not in `sel_list`. """ sel_list = to_list(sel_list) assert isinstance(dict_obj, dict) - sel_dict = {k: v for k, v in dict_obj.items() if k in sel_list} + sel_dict = {k: v for k, v in dict_obj.items() if (k in sel_list) is filter_cond} sel_dict = sel_dict if not remove_none else {k: v for k, v in sel_dict.items() if v is not None} return sel_dict @@ -252,6 +253,19 @@ def convert2xrda(arr: Union[xr.DataArray, xr.Dataset, np.ndarray, int, float], return xr.DataArray(arr, **kwargs) +def filter_dict_by_value(dictionary: dict, filter_val: Any, filter_cond: bool) -> dict: + """ + Filter dictionary by its values. + + :param dictionary: dict to filter + :param filter_val: search only for key value pair with a value equal to filter_val + :param filter_cond: indicate to use either all dict entries that fulfil the filter_val criteria (if `True`) or that + do not match the criteria (if `False`) + :returns: a filtered dict with either matching or non-matching elements depending on the `filter_cond` + """ + return dict(filter(lambda x: (x[1] == filter_val) is filter_cond, dictionary.items())) + + # def convert_size(size_bytes): # if size_bytes == 0: # return "0B" diff --git a/test/test_helpers/test_helpers.py b/test/test_helpers/test_helpers.py index 70640be9d56d71e4f68145b3bb68fb835e1e27a5..87c0f9ecb7f0a67267f0e24f9da035fc8315d56d 100644 --- a/test/test_helpers/test_helpers.py +++ b/test/test_helpers/test_helpers.py @@ -12,8 +12,9 @@ import mock import pytest import string -from mlair.helpers import to_list, dict_to_xarray, float_round, remove_items, extract_value, select_from_dict, sort_like -from mlair.helpers import PyTestRegex +from mlair.helpers import to_list, dict_to_xarray, float_round, remove_items, extract_value, select_from_dict, \ + sort_like, filter_dict_by_value +from mlair.helpers import PyTestRegex, check_nested_equality from mlair.helpers import Logger, TimeTracking from mlair.helpers.helpers import is_xarray, convert2xrda, relative_round @@ -223,6 +224,10 @@ class TestSelectFromDict: assert select_from_dict(dictionary, ["a", "e"]) == {"a": 1, "e": None} assert select_from_dict(dictionary, ["a", "e"], remove_none=True) == {"a": 1} + def test_select_condition(self, dictionary): + assert select_from_dict(dictionary, ["a", "e"], filter_cond=False) == {"b": 23, "c": "last"} + assert select_from_dict(dictionary, ["a", "c"], filter_cond=False, remove_none=True) == {"b": 23} + class TestRemoveItems: @@ -487,3 +492,22 @@ class TestSortLike: l_obj = [1, 2, 3, 8, 4] with pytest.raises(AssertionError) as e: sort_like(l_obj, [1, 2, 3, 5, 6, 7, 8]) + + +class TestFilterDictByValue: + + def test_filter_dict_by_value(self): + data_origin = {'o3': '', 'no': '', 'no2': '', 'relhum': 'REA', 'u': 'REA', 'cloudcover': 'REA', 'temp': 'era5'} + expected = {'temp': 'era5'} + assert check_nested_equality(filter_dict_by_value(data_origin, "era", True), expected) is True + expected = {'o3': '', 'no': '', 'no2': '', 'relhum': 'REA', 'u': 'REA', 'cloudcover': 'REA'} + assert check_nested_equality(filter_dict_by_value(data_origin, "era", False), expected) is True + expected = {'o3': '', 'no': '', 'no2': ''} + assert check_nested_equality(filter_dict_by_value(data_origin, "", True), expected) is True + + def test_filter_dict_by_value_not_avail(self): + data_origin = {'o3': '', 'no': '', 'no2': '', 'relhum': 'REA', 'u': 'REA', 'cloudcover': 'REA', 'temp': 'era5'} + expected = {} + assert check_nested_equality(filter_dict_by_value(data_origin, "not_avail", True), expected) is True + assert check_nested_equality(filter_dict_by_value(data_origin, "EA", True), expected) is True +