from mlair.helpers.testing import PyTestRegex, PyTestAllEqual

import re
import xarray as xr
import numpy as np

import pytest


class TestPyTestRegex:
    """Tests for ``PyTestRegex``: construction, string equality and ``repr``."""

    def test_init(self):
        test_regex = PyTestRegex(r"TestString\d+")
        # ``re._pattern_type`` was a private alias removed in Python 3.7 (the
        # public name is ``re.Pattern``). Taking the type of a freshly compiled
        # pattern works on every Python version.
        assert isinstance(test_regex._regex, type(re.compile("")))

    def test_eq(self):
        assert PyTestRegex(r"TestString\d*") == "TestString"
        assert PyTestRegex(r"TestString\d+") == "TestString9"
        # Equality must also hold when the plain string is on the left-hand side.
        assert "TestString4" == PyTestRegex(r"TestString\d+")

    def test_repr(self):
        # repr() is expected to yield the raw pattern text itself.
        assert repr(PyTestRegex(r"TestString\d+")) == r"TestString\d+"


class TestPyTestAllEqual:
    """Tests for ``PyTestAllEqual`` on numpy, xarray, plain and nested inputs."""

    def test_numpy(self):
        # Three distinct but element-wise identical arrays.
        identical = [np.array([1, 2, 3]) for _ in range(3)]
        assert PyTestAllEqual(identical)
        # Only the middle array differs (first element 2 instead of 1).
        mismatched = [np.array([1, 2, 3]), np.array([2, 2, 3]), np.array([1, 2, 3])]
        with pytest.raises(AssertionError):
            PyTestAllEqual(mismatched)

    def test_xarray(self):
        identical = [xr.DataArray([1, 2, 3]) for _ in range(2)]
        assert PyTestAllEqual(identical)
        # The second array is one element longer than the first.
        mismatched = [xr.DataArray([1, 2, 3]), xr.DataArray([1, 2, 3, 4])]
        with pytest.raises(AssertionError):
            PyTestAllEqual(mismatched)

    def test_other(self):
        assert PyTestAllEqual(["test" for _ in range(3)])
        with pytest.raises(AssertionError):
            PyTestAllEqual(["test", "test", "tes2t"])

    def test_encapsulated(self):
        def pair_of_pairs(factory):
            # Two separate inner lists, each holding two freshly built containers.
            return [[factory([1, 2, 3]), factory([12, 22, 32])] for _ in range(2)]

        assert PyTestAllEqual(pair_of_pairs(np.array))
        assert PyTestAllEqual(pair_of_pairs(xr.DataArray))
        assert PyTestAllEqual([["test", "test2"] for _ in range(2)])