From a8f12a620b4646c55ca57ef4e3a1c9304b2de4f9 Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Thu, 9 Jun 2022 11:45:51 +0200
Subject: [PATCH] new helpers to filter dicts

---
 mlair/helpers/__init__.py         |  5 +++--
 mlair/helpers/helpers.py          | 22 ++++++++++++++++++----
 test/test_helpers/test_helpers.py | 28 ++++++++++++++++++++++++++--
 3 files changed, 47 insertions(+), 8 deletions(-)

diff --git a/mlair/helpers/__init__.py b/mlair/helpers/__init__.py
index 3a5b8699..cf50fa05 100644
--- a/mlair/helpers/__init__.py
+++ b/mlair/helpers/__init__.py
@@ -1,6 +1,7 @@
 """Collection of different supporting functions and classes."""
 
-from .testing import PyTestRegex, PyTestAllEqual
+from .testing import PyTestRegex, PyTestAllEqual, check_nested_equality
 from .time_tracking import TimeTracking, TimeTrackingWrapper
 from .logger import Logger
-from .helpers import remove_items, float_round, dict_to_xarray, to_list, extract_value, select_from_dict, make_keras_pickable, sort_like
+from .helpers import remove_items, float_round, dict_to_xarray, to_list, extract_value, select_from_dict, \
+    make_keras_pickable, sort_like, filter_dict_by_value
diff --git a/mlair/helpers/helpers.py b/mlair/helpers/helpers.py
index b583cf7d..f69e5b20 100644
--- a/mlair/helpers/helpers.py
+++ b/mlair/helpers/helpers.py
@@ -176,16 +176,17 @@ def remove_items(obj: Union[List, Dict, Tuple], items: Any):
         raise TypeError(f"{inspect.stack()[0][3]} does not support type {type(obj)}.")
 
 
-def select_from_dict(dict_obj: dict, sel_list: Any, remove_none=False):
+def select_from_dict(dict_obj: dict, sel_list: Any, remove_none: bool = False, filter_cond: bool = True) -> dict:
     """
     Extract all key values pairs whose key is contained in the sel_list.
 
-    Does not perform a check if all elements of sel_list are keys of dict_obj. Therefore the number of pairs in the
-    returned dict is always smaller or equal to the number of elements in the sel_list.
+    Does not perform a check if all elements of sel_list are keys of dict_obj. Therefore, with `filter_cond=True` the
+    returned dict never has more pairs than sel_list has elements. With `filter_cond=False` the selection is inverted
+    and all pairs whose key is not contained in sel_list are returned. `remove_none` drops pairs whose value is None.
     """
     sel_list = to_list(sel_list)
     assert isinstance(dict_obj, dict)
-    sel_dict = {k: v for k, v in dict_obj.items() if k in sel_list}
+    sel_dict = {k: v for k, v in dict_obj.items() if (k in sel_list) == bool(filter_cond)}
     sel_dict = sel_dict if not remove_none else {k: v for k, v in sel_dict.items() if v is not None}
     return sel_dict
 
@@ -252,6 +253,19 @@ def convert2xrda(arr: Union[xr.DataArray, xr.Dataset, np.ndarray, int, float],
         return xr.DataArray(arr, **kwargs)
 
 
+def filter_dict_by_value(dictionary: dict, filter_val: Any, filter_cond: bool) -> dict:
+    """
+    Filter dictionary by its values using an equality check (`value == filter_val`).
+
+    :param dictionary: dict to filter, the input is left untouched
+    :param filter_val: search only for key value pairs with a value equal to filter_val (no substring/partial match)
+    :param filter_cond: indicate to use either all dict entries that fulfil the filter_val criteria (if `True`) or that
+        do not match the criteria (if `False`); evaluated by truthiness so that e.g. `numpy.bool_` is accepted as well
+    :returns: a new dict with either matching or non-matching elements depending on the `filter_cond`
+    """
+    return {k: v for k, v in dictionary.items() if bool(v == filter_val) == bool(filter_cond)}
+
+
 # def convert_size(size_bytes):
 #     if size_bytes == 0:
 #         return "0B"
diff --git a/test/test_helpers/test_helpers.py b/test/test_helpers/test_helpers.py
index 70640be9..87c0f9ec 100644
--- a/test/test_helpers/test_helpers.py
+++ b/test/test_helpers/test_helpers.py
@@ -12,8 +12,9 @@ import mock
 import pytest
 import string
 
-from mlair.helpers import to_list, dict_to_xarray, float_round, remove_items, extract_value, select_from_dict, sort_like
-from mlair.helpers import PyTestRegex
+from mlair.helpers import to_list, dict_to_xarray, float_round, remove_items, extract_value, select_from_dict, \
+    sort_like, filter_dict_by_value
+from mlair.helpers import PyTestRegex, check_nested_equality
 from mlair.helpers import Logger, TimeTracking
 from mlair.helpers.helpers import is_xarray, convert2xrda, relative_round
 
@@ -223,6 +224,10 @@ class TestSelectFromDict:
         assert select_from_dict(dictionary, ["a", "e"]) == {"a": 1, "e": None}
         assert select_from_dict(dictionary, ["a", "e"], remove_none=True) == {"a": 1}
 
+    def test_select_condition(self, dictionary):
+        assert select_from_dict(dictionary, ["a", "e"], remove_none=False, filter_cond=False) == {"b": 23, "c": "last"}
+        assert select_from_dict(dictionary, ["a", "c"], remove_none=True, filter_cond=False) == {"b": 23}
+
 
 class TestRemoveItems:
 
@@ -487,3 +492,22 @@ class TestSortLike:
         l_obj = [1, 2, 3, 8, 4]
         with pytest.raises(AssertionError) as e:
             sort_like(l_obj, [1, 2, 3, 5, 6, 7, 8])
+
+
+class TestFilterDictByValue:
+
+    def test_filter_dict_by_value(self):
+        data_origin = {'o3': '', 'no': '', 'no2': '', 'relhum': 'REA', 'u': 'REA', 'cloudcover': 'REA', 'temp': 'era5'}
+        expected = {'temp': 'era5'}  # values are matched by equality, so the complete value "era5" is required
+        assert check_nested_equality(filter_dict_by_value(data_origin, "era5", True), expected) is True
+        expected = {'o3': '', 'no': '', 'no2': '', 'relhum': 'REA', 'u': 'REA', 'cloudcover': 'REA'}
+        assert check_nested_equality(filter_dict_by_value(data_origin, "era5", False), expected) is True
+        expected = {'o3': '', 'no': '', 'no2': ''}
+        assert check_nested_equality(filter_dict_by_value(data_origin, "", True), expected) is True
+
+    def test_filter_dict_by_value_not_avail(self):
+        data_origin = {'o3': '', 'no': '', 'no2': '', 'relhum': 'REA', 'u': 'REA', 'cloudcover': 'REA', 'temp': 'era5'}
+        expected = {}  # neither an unknown value nor a substring like "EA" (part of "REA") may match
+        assert check_nested_equality(filter_dict_by_value(data_origin, "not_avail", True), expected) is True
+        assert check_nested_equality(filter_dict_by_value(data_origin, "EA", True), expected) is True
+
-- 
GitLab