Commit a8f12a62 authored by lukas leufen
Browse files

new helpers to filter dicts

parent a58a6487
Pipeline #102335 failed with stages
in 11 minutes and 40 seconds
"""Collection of different supporting functions and classes."""
from .testing import PyTestRegex, PyTestAllEqual
from .testing import PyTestRegex, PyTestAllEqual, check_nested_equality
from .time_tracking import TimeTracking, TimeTrackingWrapper
from .logger import Logger
from .helpers import remove_items, float_round, dict_to_xarray, to_list, extract_value, select_from_dict, make_keras_pickable, sort_like
from .helpers import remove_items, float_round, dict_to_xarray, to_list, extract_value, select_from_dict, \
make_keras_pickable, sort_like, filter_dict_by_value
......@@ -176,16 +176,17 @@ def remove_items(obj: Union[List, Dict, Tuple], items: Any):
raise TypeError(f"{inspect.stack()[0][3]} does not support type {type(obj)}.")
def select_from_dict(dict_obj: dict, sel_list: Any, remove_none: bool = False, filter_cond: bool = True) -> dict:
    """
    Extract all key value pairs whose key is (or is not) contained in the sel_list.

    Does not perform a check if all elements of sel_list are keys of dict_obj. Therefore, when selecting the included
    keys (`filter_cond=True`), the number of pairs in the returned dict is always smaller or equal to the number of
    elements in the sel_list. With `filter_cond=False` the complement is returned instead, i.e. all pairs whose key is
    not contained in the sel_list.

    :param dict_obj: dictionary to select from
    :param sel_list: key or collection of keys to look for; it is passed through `to_list` before the membership test
    :param remove_none: if `True`, pairs of the selection whose value is None are dropped from the result
    :param filter_cond: if truthy (default) return the pairs whose key is in `sel_list`, if falsy return the pairs
        whose key is not in `sel_list`
    :returns: a new dict holding the selected pairs, the input dict is not modified
    :raises AssertionError: if `dict_obj` is not a dict
    """
    sel_list = to_list(sel_list)
    assert isinstance(dict_obj, dict)
    # Normalise the flag once: comparing with `is` against the raw argument silently yields an empty result for
    # truthy/falsy values that are not the bool singletons (e.g. 1 or 0).
    keep_listed = bool(filter_cond)
    sel_dict = {k: v for k, v in dict_obj.items() if (k in sel_list) == keep_listed}
    if remove_none:
        sel_dict = {k: v for k, v in sel_dict.items() if v is not None}
    return sel_dict
......@@ -252,6 +253,19 @@ def convert2xrda(arr: Union[xr.DataArray, xr.Dataset, np.ndarray, int, float],
return xr.DataArray(arr, **kwargs)
def filter_dict_by_value(dictionary: dict, filter_val: Any, filter_cond: bool) -> dict:
    """
    Filter dictionary by its values.

    A pair counts as matching only if its value compares equal (`==`) to `filter_val`; partial matches such as a
    common prefix are not considered.

    :param dictionary: dict to filter
    :param filter_val: search only for key value pair with a value equal to filter_val
    :param filter_cond: indicate to use either all dict entries that fulfil the filter_val criteria (if truthy) or that
        do not match the criteria (if falsy)

    :returns: a new filtered dict with either matching or non-matching elements depending on the `filter_cond`
    """
    # Normalise the flag once: an identity check against the raw argument would never succeed for truthy/falsy
    # values that are not the bool singletons (e.g. 1 or 0) and silently return an empty dict.
    keep_matching = bool(filter_cond)
    return {key: value for key, value in dictionary.items() if (value == filter_val) is keep_matching}
# def convert_size(size_bytes):
# if size_bytes == 0:
# return "0B"
......
......@@ -12,8 +12,9 @@ import mock
import pytest
import string
from mlair.helpers import to_list, dict_to_xarray, float_round, remove_items, extract_value, select_from_dict, sort_like
from mlair.helpers import PyTestRegex
from mlair.helpers import to_list, dict_to_xarray, float_round, remove_items, extract_value, select_from_dict, \
sort_like, filter_dict_by_value
from mlair.helpers import PyTestRegex, check_nested_equality
from mlair.helpers import Logger, TimeTracking
from mlair.helpers.helpers import is_xarray, convert2xrda, relative_round
......@@ -223,6 +224,10 @@ class TestSelectFromDict:
assert select_from_dict(dictionary, ["a", "e"]) == {"a": 1, "e": None}
assert select_from_dict(dictionary, ["a", "e"], remove_none=True) == {"a": 1}
def test_select_condition(self, dictionary):
    """Check that `filter_cond=False` returns the pairs whose keys are not listed, also combined with remove_none."""
    complement = select_from_dict(dictionary, ["a", "e"], filter_cond=False)
    assert complement == {"b": 23, "c": "last"}
    complement_without_none = select_from_dict(dictionary, ["a", "c"], filter_cond=False, remove_none=True)
    assert complement_without_none == {"b": 23}
class TestRemoveItems:
......@@ -487,3 +492,22 @@ class TestSortLike:
l_obj = [1, 2, 3, 8, 4]
with pytest.raises(AssertionError) as e:
sort_like(l_obj, [1, 2, 3, 5, 6, 7, 8])
class TestFilterDictByValue:
    """Tests for filter_dict_by_value, which selects dict entries by equality of their value."""

    def test_filter_dict_by_value(self):
        data_origin = {'o3': '', 'no': '', 'no2': '', 'relhum': 'REA', 'u': 'REA', 'cloudcover': 'REA', 'temp': 'era5'}
        # Values are compared by equality, so the complete value 'era5' has to be given; a prefix like 'era' would
        # not match any entry.
        expected = {'temp': 'era5'}
        assert check_nested_equality(filter_dict_by_value(data_origin, "era5", True), expected) is True
        expected = {'o3': '', 'no': '', 'no2': '', 'relhum': 'REA', 'u': 'REA', 'cloudcover': 'REA'}
        assert check_nested_equality(filter_dict_by_value(data_origin, "era5", False), expected) is True
        expected = {'o3': '', 'no': '', 'no2': ''}
        assert check_nested_equality(filter_dict_by_value(data_origin, "", True), expected) is True

    def test_filter_dict_by_value_not_avail(self):
        data_origin = {'o3': '', 'no': '', 'no2': '', 'relhum': 'REA', 'u': 'REA', 'cloudcover': 'REA', 'temp': 'era5'}
        expected = {}
        # Neither an unknown value nor a mere substring of an existing value ('EA' in 'REA') selects anything.
        assert check_nested_equality(filter_dict_by_value(data_origin, "not_avail", True), expected) is True
        assert check_nested_equality(filter_dict_by_value(data_origin, "EA", True), expected) is True
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment