From 399a383ec7c1505a56ede98201c17ce96223faf3 Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Thu, 9 Dec 2021 10:46:58 +0100
Subject: [PATCH] added test for test_nested_equality, /close #345

---
 mlair/helpers/testing.py                  |  5 +++-
 test/test_helpers/test_testing_helpers.py | 34 ++++++++++++++++++++++-
 2 files changed, 37 insertions(+), 2 deletions(-)

diff --git a/mlair/helpers/testing.py b/mlair/helpers/testing.py
index 8c3b301d..1fb8012f 100644
--- a/mlair/helpers/testing.py
+++ b/mlair/helpers/testing.py
@@ -96,7 +96,7 @@ def test_nested_equality(obj1, obj2):
 
         if isinstance(obj1, (tuple, list)):
             print(f"check length {len(obj1)} and {len(obj2)}")
-            assert len(obj1) == len(obj1)
+            assert len(obj1) == len(obj2)
             for pos in range(len(obj1)):
                 print(f"check pos {obj1[pos]} and {obj2[pos]}")
                 assert test_nested_equality(obj1[pos], obj2[pos]) is True
@@ -109,6 +109,9 @@ def test_nested_equality(obj1, obj2):
         elif isinstance(obj1, xr.DataArray):
             print(f"check xr {obj1} and {obj2}")
             assert xr.testing.assert_equal(obj1, obj2) is None
+        elif isinstance(obj1, np.ndarray):
+            print(f"check np {obj1} and {obj2}")
+            assert np.testing.assert_array_equal(obj1, obj2) is None
         else:
             print(f"check equal {obj1} and {obj2}")
             assert obj1 == obj2
diff --git a/test/test_helpers/test_testing_helpers.py b/test/test_helpers/test_testing_helpers.py
index 385161c7..83ba0101 100644
--- a/test/test_helpers/test_testing_helpers.py
+++ b/test/test_helpers/test_testing_helpers.py
@@ -1,4 +1,4 @@
-from mlair.helpers.testing import PyTestRegex, PyTestAllEqual
+from mlair.helpers.testing import PyTestRegex, PyTestAllEqual, test_nested_equality
 
 import re
 import xarray as xr
@@ -46,3 +46,35 @@ class TestPyTestAllEqual:
                                [xr.DataArray([1, 2, 3]), xr.DataArray([12, 22, 32])]])
         assert PyTestAllEqual([["test", "test2"],
                                ["test", "test2"]])
+
+
+class TestNestedEquality:
+
+    def test_nested_equality_single_entries(self):
+        assert test_nested_equality(3, 3) is True
+        assert test_nested_equality(3.9, 3.9) is True
+        assert test_nested_equality(3.91, 3.9) is False
+        assert test_nested_equality("3", 3) is False
+        assert test_nested_equality("3", "3") is True
+        assert test_nested_equality(None, None) is True
+
+    def test_nested_equality_xarray(self):
+        obj1 = xr.DataArray(np.random.randn(2, 3), dims=('x', 'y'), coords={'x': [10, 20], 'y': [0, 10, 20]})
+        obj2 = xr.ones_like(obj1) * obj1
+        assert test_nested_equality(obj1, obj2) is True
+
+    def test_nested_equality_numpy(self):
+        obj1 = np.random.randn(2, 3)
+        obj2 = obj1 * 1
+        assert test_nested_equality(obj1, obj2) is True
+
+    def test_nested_equality_list_tuple(self):
+        assert test_nested_equality([3, 3], [3, 3]) is True
+        assert test_nested_equality((2, 6), (2, 6)) is True
+        assert test_nested_equality([3, 3.5], [3.5, 3]) is False
+        assert test_nested_equality([3, 3.5, 10], [3, 3.5]) is False
+
+    def test_nested_equality_dict(self):
+        assert test_nested_equality({"a": 3, "b": 10}, {"b": 10, "a": 3}) is True
+        assert test_nested_equality({"a": 3, "b": [10, 100]}, {"b": [10, 100], "a": 3}) is True
+        assert test_nested_equality({"a": 3, "b": 10, "c": "c"}, {"b": 10, "a": 3}) is False
-- 
GitLab