diff --git a/mlair/helpers/testing.py b/mlair/helpers/testing.py index 244eb69fdc46dcadaeb3ada5779f09d44aa83e2a..abb50883c7af49a0c1571d99f737e310abff9b13 100644 --- a/mlair/helpers/testing.py +++ b/mlair/helpers/testing.py @@ -35,54 +35,54 @@ class PyTestRegex: return self._regex.pattern -class PyTestAllEqual: - """ - Check if all elements in list are the same. - - :param check_list: list with elements to check - """ - - def __init__(self, check_list: List): - """Construct class.""" - self._list = check_list - self._test_function = None - - def _set_test_function(self): - if isinstance(self._list[0], np.ndarray): - self._test_function = np.testing.assert_array_equal - else: - self._test_function = xr.testing.assert_equal - - def _check_all_equal(self) -> bool: - """ - Check if all elements are equal. - - :return boolean if elements are equal +def PyTestAllEqual(check_list: List): + class PyTestAllEqualClass: """ - equal = True - self._set_test_function() - for b in self._list: - equal *= self._test_function(self._list[0], b) is None - return bool(equal == 1) + Check if all elements in list are the same. - def is_true(self) -> bool: + :param check_list: list with elements to check """ - Start equality check. - :return: true if equality test is passed, false otherwise - """ - return self._check_all_equal() - - -def xr_all_equal(check_list: List) -> bool: - """ - Check if all given elements (preferably xarray's) in list are equal. - - :param check_list: list with elements to check - - :return: boolean if all elements are the same or not - """ - equal = True - for b in check_list: - equal *= xr.testing.assert_equal(check_list[0], b) is None - return equal == 1 \ No newline at end of file + def __init__(self, check_list: List): + """Construct class.""" + self._list = check_list + self._test_function = None + + def _set_test_function(self, _list): + if isinstance(_list[0], list): + _test_function = self._set_test_function(_list[0]) + self._test_function = lambda r, s: all(map(lambda x, y: _test_function(x, y) is None, r, s)) + elif isinstance(_list[0], np.ndarray): + self._test_function = np.testing.assert_array_equal + elif isinstance(_list[0], xr.DataArray): + self._test_function = xr.testing.assert_equal + else: + self._test_function = lambda x, y: self._assert(x, y) + # raise TypeError(f"given type {type(_list[0])} is not supported by PyTestAllEqual.") + return self._test_function + + @staticmethod + def _assert(x, y): + assert x == y + + def _check_all_equal(self) -> bool: + """ + Check if all elements are equal. + + :return boolean if elements are equal + """ + equal = True + self._set_test_function(self._list) + for b in self._list: + equal *= self._test_function(self._list[0], b) in [None, True] + return bool(equal == 1) + + def is_true(self) -> bool: + """ + Start equality check. + + :return: true if equality test is passed, false otherwise + """ + return self._check_all_equal() + + return PyTestAllEqualClass(check_list).is_true() diff --git a/test/test_data_handler/test_iterator.py b/test/test_data_handler/test_iterator.py index 2bd33cc3aeea6bc631323e3d75d0011baacabad3..ade5c19215e61de5e209db900920187294ac9b18 100644 --- a/test/test_data_handler/test_iterator.py +++ b/test/test_data_handler/test_iterator.py @@ -255,10 +255,6 @@ class TestKerasIterator: expected = next(iter(collection)) assert PyTestAllEqual([X, expected.get_X()]) assert PyTestAllEqual([Y, expected.get_Y()]) - reversed(iterator.indexes) - X, Y = iterator[3] - assert PyTestAllEqual([X, expected.get_X()]) - assert PyTestAllEqual([Y, expected.get_Y()]) def test_on_epoch_end(self): iterator = object.__new__(KerasIterator) diff --git a/test/test_helpers/test_testing_helpers.py b/test/test_helpers/test_testing_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..385161c740f386847ef2f2dc4df17c1c84fa7fa5 --- /dev/null +++ b/test/test_helpers/test_testing_helpers.py @@ -0,0 +1,48 @@ +from mlair.helpers.testing import PyTestRegex, PyTestAllEqual + +import re +import xarray as xr +import numpy as np + +import pytest + + +class TestPyTestRegex: + + def test_init(self): + test_regex = PyTestRegex(r"TestString\d+") + assert isinstance(test_regex._regex, re._pattern_type) + + def test_eq(self): + assert PyTestRegex(r"TestString\d*") == "TestString" + assert PyTestRegex(r"TestString\d+") == "TestString9" + assert "TestString4" == PyTestRegex(r"TestString\d+") + + def test_repr(self): + assert repr(PyTestRegex(r"TestString\d+")) == r"TestString\d+" + + +class TestPyTestAllEqual: + + def test_numpy(self): + assert PyTestAllEqual([np.array([1, 2, 3]), np.array([1, 2, 3]), np.array([1, 2, 3])]) + with pytest.raises(AssertionError): + PyTestAllEqual([np.array([1, 2, 3]), np.array([2, 2, 3]), np.array([1, 2, 3])]) + + def test_xarray(self): + assert PyTestAllEqual([xr.DataArray([1, 2, 3]), xr.DataArray([1, 2, 3])]) + with pytest.raises(AssertionError): + PyTestAllEqual([xr.DataArray([1, 2, 3]), xr.DataArray([1, 2, 3, 4])]) + + def test_other(self): + assert PyTestAllEqual(["test", "test", "test"]) + with pytest.raises(AssertionError): + PyTestAllEqual(["test", "test", "tes2t"]) + + def test_encapsulated(self): + assert PyTestAllEqual([[np.array([1, 2, 3]), np.array([12, 22, 32])], + [np.array([1, 2, 3]), np.array([12, 22, 32])]]) + assert PyTestAllEqual([[xr.DataArray([1, 2, 3]), xr.DataArray([12, 22, 32])], + [xr.DataArray([1, 2, 3]), xr.DataArray([12, 22, 32])]]) + assert PyTestAllEqual([["test", "test2"], + ["test", "test2"]]) diff --git a/test/test_plotting/test_tracker_plot.py b/test/test_plotting/test_tracker_plot.py index 196879657452fe12238c990fc419cb0848c9ec9c..9587e71352dd4648009a72a7046c2b068dd5584d 100644 --- a/test/test_plotting/test_tracker_plot.py +++ b/test/test_plotting/test_tracker_plot.py @@ -356,13 +356,13 @@ class TestTrackPlot: assert len(track_plot_obj.ax.lines) == 0 track_plot_obj.line(start_x=5, end_x=6, y=2) assert len(track_plot_obj.ax.lines) == 2 - pos_x, pos_y = np.array([5 + w, 6]), np.ones((2, )) * (2 + h / 2) + pos_x, pos_y = np.array([5 + w, 6]), np.ones((2,)) * (2 + h / 2) assert track_plot_obj.ax.lines[0]._color == "white" assert track_plot_obj.ax.lines[0]._linewidth == 2.5 assert track_plot_obj.ax.lines[1]._color == "darkgrey" assert track_plot_obj.ax.lines[1]._linewidth == 1.4 - assert PyTestAllEqual([track_plot_obj.ax.lines[0]._x, track_plot_obj.ax.lines[1]._x, pos_x]).is_true() - assert PyTestAllEqual([track_plot_obj.ax.lines[0]._y, track_plot_obj.ax.lines[1]._y, pos_y]).is_true() + assert PyTestAllEqual([track_plot_obj.ax.lines[0]._x, track_plot_obj.ax.lines[1]._x, pos_x]) + assert PyTestAllEqual([track_plot_obj.ax.lines[0]._y, track_plot_obj.ax.lines[1]._y, pos_y]) def test_step(self, track_plot_obj): x_int, h, w = 0.5, 0.6, 0.65 @@ -379,8 +379,8 @@ class TestTrackPlot: assert track_plot_obj.ax.lines[0]._linewidth == 2.5 assert track_plot_obj.ax.lines[1]._color == "black" assert track_plot_obj.ax.lines[1]._linewidth == 1.4 - assert PyTestAllEqual([track_plot_obj.ax.lines[0]._x, track_plot_obj.ax.lines[1]._x, pos_x]).is_true() - assert PyTestAllEqual([track_plot_obj.ax.lines[0]._y, track_plot_obj.ax.lines[1]._y, pos_y]).is_true() + assert PyTestAllEqual([track_plot_obj.ax.lines[0]._x, track_plot_obj.ax.lines[1]._x, pos_x]) + assert PyTestAllEqual([track_plot_obj.ax.lines[0]._y, track_plot_obj.ax.lines[1]._y, pos_y]) def test_rect(self, track_plot_obj): h, w = 0.5, 0.6