From 194fd5de3dcd91ba6d9619c55000370bacd20c78 Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Wed, 4 May 2022 16:56:40 +0200
Subject: [PATCH] adjusted helper function to work for single numbers with
 precision

---
 mlair/helpers/testing.py                  | 12 ++++++++++--
 test/test_helpers/test_testing_helpers.py |  6 +++---
 2 files changed, 13 insertions(+), 5 deletions(-)

diff --git a/mlair/helpers/testing.py b/mlair/helpers/testing.py
index 15c2d79c..eb8982ae 100644
--- a/mlair/helpers/testing.py
+++ b/mlair/helpers/testing.py
@@ -141,8 +141,16 @@ def check_nested_equality(obj1, obj2, precision=None):
                 print(f"check np {obj1} and {obj2} with precision {precision}")
                 assert np.testing.assert_array_almost_equal(obj1, obj2, decimal=precision) is None
         else:
-            print(f"check equal {obj1} and {obj2}")
-            assert obj1 == obj2
+            if isinstance(obj1, (int, float)) and isinstance(obj2, (int, float)):
+                if precision is None:
+                    print(f"check number equal {obj1} and {obj2}")
+                    assert np.testing.assert_equal(obj1, obj2) is None
+                else:
+                    print(f"check number equal {obj1} and {obj2} with precision {precision}")
+                    assert np.testing.assert_almost_equal(obj1, obj2, decimal=precision) is None
+            else:
+                print(f"check equal {obj1} and {obj2}")
+                assert obj1 == obj2
     except AssertionError:
         return False
     return True
diff --git a/test/test_helpers/test_testing_helpers.py b/test/test_helpers/test_testing_helpers.py
index c15a7ea9..8a4bdb92 100644
--- a/test/test_helpers/test_testing_helpers.py
+++ b/test/test_helpers/test_testing_helpers.py
@@ -58,14 +58,14 @@ class TestNestedEquality:
         assert check_nested_equality("3", 3) is False
         assert check_nested_equality("3", "3") is True
         assert check_nested_equality(None, None) is True
-        assert check_nested_equality(3.91, 3.9, 1) is True
-        assert check_nested_equality(3.91, 3.9, 2) is False
+        assert check_nested_equality(3.92, 3.9, 1) is True
+        assert check_nested_equality(3.92, 3.9, 2) is False
 
     def test_nested_equality_xarray(self):
         obj1 = xr.DataArray(np.random.randn(2, 3), dims=('x', 'y'), coords={'x': [10, 20], 'y': [0, 10, 20]})
         obj2 = xr.ones_like(obj1) * obj1
         assert check_nested_equality(obj1, obj2) is True
-        obj2 = 1.0001 * obj2
+        obj2 = obj2 * 1.0001
         assert check_nested_equality(obj1, obj2) is False
         assert check_nested_equality(obj1, obj2, 3) is True
 
-- 
GitLab