Skip to content
Snippets Groups Projects
Commit d016cf66 authored by leufen1's avatar leufen1
Browse files

new tests for testing helpers, corrected wrong usage of PyTestAllEqual and...

new tests for testing helpers, corrected wrong usage of PyTestAllEqual and also refactored this method, /close #242
parent 4e9f96d0
Branches
Tags
3 merge requests!226Develop,!225Resolve "release v1.2.0",!220Resolve "new tests for helper testing"
Pipeline #55607 passed
...@@ -35,7 +35,8 @@ class PyTestRegex: ...@@ -35,7 +35,8 @@ class PyTestRegex:
return self._regex.pattern return self._regex.pattern
class PyTestAllEqual: def PyTestAllEqual(check_list: List):
class PyTestAllEqualClass:
""" """
Check if all elements in list are the same. Check if all elements in list are the same.
...@@ -47,11 +48,22 @@ class PyTestAllEqual: ...@@ -47,11 +48,22 @@ class PyTestAllEqual:
self._list = check_list self._list = check_list
self._test_function = None self._test_function = None
def _set_test_function(self): def _set_test_function(self, _list):
if isinstance(self._list[0], np.ndarray): if isinstance(_list[0], list):
_test_function = self._set_test_function(_list[0])
self._test_function = lambda r, s: all(map(lambda x, y: _test_function(x, y) is None, r, s))
elif isinstance(_list[0], np.ndarray):
self._test_function = np.testing.assert_array_equal self._test_function = np.testing.assert_array_equal
else: elif isinstance(_list[0], xr.DataArray):
self._test_function = xr.testing.assert_equal self._test_function = xr.testing.assert_equal
else:
self._test_function = lambda x, y: self._assert(x, y)
# raise TypeError(f"given type {type(_list[0])} is not supported by PyTestAllEqual.")
return self._test_function
@staticmethod
def _assert(x, y):
assert x == y
def _check_all_equal(self) -> bool: def _check_all_equal(self) -> bool:
""" """
...@@ -60,9 +72,9 @@ class PyTestAllEqual: ...@@ -60,9 +72,9 @@ class PyTestAllEqual:
:return boolean if elements are equal :return boolean if elements are equal
""" """
equal = True equal = True
self._set_test_function() self._set_test_function(self._list)
for b in self._list: for b in self._list:
equal *= self._test_function(self._list[0], b) is None equal *= self._test_function(self._list[0], b) in [None, True]
return bool(equal == 1) return bool(equal == 1)
def is_true(self) -> bool: def is_true(self) -> bool:
...@@ -73,16 +85,4 @@ class PyTestAllEqual: ...@@ -73,16 +85,4 @@ class PyTestAllEqual:
""" """
return self._check_all_equal() return self._check_all_equal()
return PyTestAllEqualClass(check_list).is_true()
def xr_all_equal(check_list: List) -> bool:
"""
Check if all given elements (preferably xarray's) in list are equal.
:param check_list: list with elements to check
:return: boolean if all elements are the same or not
"""
equal = True
for b in check_list:
equal *= xr.testing.assert_equal(check_list[0], b) is None
return equal == 1
\ No newline at end of file
...@@ -255,10 +255,6 @@ class TestKerasIterator: ...@@ -255,10 +255,6 @@ class TestKerasIterator:
expected = next(iter(collection)) expected = next(iter(collection))
assert PyTestAllEqual([X, expected.get_X()]) assert PyTestAllEqual([X, expected.get_X()])
assert PyTestAllEqual([Y, expected.get_Y()]) assert PyTestAllEqual([Y, expected.get_Y()])
reversed(iterator.indexes)
X, Y = iterator[3]
assert PyTestAllEqual([X, expected.get_X()])
assert PyTestAllEqual([Y, expected.get_Y()])
def test_on_epoch_end(self): def test_on_epoch_end(self):
iterator = object.__new__(KerasIterator) iterator = object.__new__(KerasIterator)
......
from mlair.helpers.testing import PyTestRegex, PyTestAllEqual
import re
import xarray as xr
import numpy as np
import pytest
class TestPyTestRegex:
def test_init(self):
test_regex = PyTestRegex(r"TestString\d+")
assert isinstance(test_regex._regex, re._pattern_type)
def test_eq(self):
assert PyTestRegex(r"TestString\d*") == "TestString"
assert PyTestRegex(r"TestString\d+") == "TestString9"
assert "TestString4" == PyTestRegex(r"TestString\d+")
def test_repr(self):
assert repr(PyTestRegex(r"TestString\d+")) == r"TestString\d+"
class TestPyTestAllEqual:
def test_numpy(self):
assert PyTestAllEqual([np.array([1, 2, 3]), np.array([1, 2, 3]), np.array([1, 2, 3])])
with pytest.raises(AssertionError):
PyTestAllEqual([np.array([1, 2, 3]), np.array([2, 2, 3]), np.array([1, 2, 3])])
def test_xarray(self):
assert PyTestAllEqual([xr.DataArray([1, 2, 3]), xr.DataArray([1, 2, 3])])
with pytest.raises(AssertionError):
PyTestAllEqual([xr.DataArray([1, 2, 3]), xr.DataArray([1, 2, 3, 4])])
def test_other(self):
assert PyTestAllEqual(["test", "test", "test"])
with pytest.raises(AssertionError):
PyTestAllEqual(["test", "test", "tes2t"])
def test_encapsulated(self):
assert PyTestAllEqual([[np.array([1, 2, 3]), np.array([12, 22, 32])],
[np.array([1, 2, 3]), np.array([12, 22, 32])]])
assert PyTestAllEqual([[xr.DataArray([1, 2, 3]), xr.DataArray([12, 22, 32])],
[xr.DataArray([1, 2, 3]), xr.DataArray([12, 22, 32])]])
assert PyTestAllEqual([["test", "test2"],
["test", "test2"]])
...@@ -361,8 +361,8 @@ class TestTrackPlot: ...@@ -361,8 +361,8 @@ class TestTrackPlot:
assert track_plot_obj.ax.lines[0]._linewidth == 2.5 assert track_plot_obj.ax.lines[0]._linewidth == 2.5
assert track_plot_obj.ax.lines[1]._color == "darkgrey" assert track_plot_obj.ax.lines[1]._color == "darkgrey"
assert track_plot_obj.ax.lines[1]._linewidth == 1.4 assert track_plot_obj.ax.lines[1]._linewidth == 1.4
assert PyTestAllEqual([track_plot_obj.ax.lines[0]._x, track_plot_obj.ax.lines[1]._x, pos_x]).is_true() assert PyTestAllEqual([track_plot_obj.ax.lines[0]._x, track_plot_obj.ax.lines[1]._x, pos_x])
assert PyTestAllEqual([track_plot_obj.ax.lines[0]._y, track_plot_obj.ax.lines[1]._y, pos_y]).is_true() assert PyTestAllEqual([track_plot_obj.ax.lines[0]._y, track_plot_obj.ax.lines[1]._y, pos_y])
def test_step(self, track_plot_obj): def test_step(self, track_plot_obj):
x_int, h, w = 0.5, 0.6, 0.65 x_int, h, w = 0.5, 0.6, 0.65
...@@ -379,8 +379,8 @@ class TestTrackPlot: ...@@ -379,8 +379,8 @@ class TestTrackPlot:
assert track_plot_obj.ax.lines[0]._linewidth == 2.5 assert track_plot_obj.ax.lines[0]._linewidth == 2.5
assert track_plot_obj.ax.lines[1]._color == "black" assert track_plot_obj.ax.lines[1]._color == "black"
assert track_plot_obj.ax.lines[1]._linewidth == 1.4 assert track_plot_obj.ax.lines[1]._linewidth == 1.4
assert PyTestAllEqual([track_plot_obj.ax.lines[0]._x, track_plot_obj.ax.lines[1]._x, pos_x]).is_true() assert PyTestAllEqual([track_plot_obj.ax.lines[0]._x, track_plot_obj.ax.lines[1]._x, pos_x])
assert PyTestAllEqual([track_plot_obj.ax.lines[0]._y, track_plot_obj.ax.lines[1]._y, pos_y]).is_true() assert PyTestAllEqual([track_plot_obj.ax.lines[0]._y, track_plot_obj.ax.lines[1]._y, pos_y])
def test_rect(self, track_plot_obj): def test_rect(self, track_plot_obj):
h, w = 0.5, 0.6 h, w = 0.5, 0.6
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment