import pytest
import dask.array as da
import numpy as np
import xarray as xr

from mlair.helpers.geofunctions import deg2rad_all_points, haversine_dist


class TestDeg2RadAllPoints:
    """Tests for ``deg2rad_all_points`` with scalar, numpy, dask and xarray inputs.

    Every case passes the angles 0, 30, 60 and 90 degrees (in that order) and
    checks that they are returned in radians, grouped as ``(lat1, lon1), (lat2, lon2)``.
    """

    @pytest.fixture
    def custom_np_arrays(self):
        """Return four 2-element numpy arrays holding 0, 30, 60 and 90 degrees."""
        return np.array([0., 0.]), np.array([30., 30.]), np.array([60., 60.]), np.array([90., 90.])

    @pytest.fixture
    def custom_da_arrays(self, custom_np_arrays):
        """Return the numpy fixture arrays wrapped as dask arrays."""
        # A tuple (not a one-shot generator) so the value can be unpacked more than once.
        return tuple(da.array(i) for i in custom_np_arrays)

    @pytest.fixture
    def custom_xr_arrays(self, custom_np_arrays):
        """Return the numpy fixture arrays wrapped as xarray DataArrays."""
        # A tuple (not a one-shot generator) so the value can be unpacked more than once.
        return tuple(xr.DataArray(i) for i in custom_np_arrays)

    # ``pytest.lazy_fixture`` is provided by the pytest-lazy-fixture plugin, not by pytest itself.
    @pytest.mark.parametrize("value", ((0., 30., 60., 90.), (0, 30, 60, 90),
                                       pytest.lazy_fixture('custom_np_arrays')))
    def test_deg2rad_all_points_scalar_inputs(self, value):
        """Float, int and numpy-array inputs are converted to radians."""
        (lat1, lon1), (lat2, lon2) = deg2rad_all_points(*value)
        assert lat1 == pytest.approx(0.0)
        assert lon1 == pytest.approx(np.pi / 6.)
        assert lat2 == pytest.approx(np.pi / 3.)
        assert lon2 == pytest.approx(np.pi / 2)

    def test_deg2rad_all_points_xr_arr_inputs(self, custom_xr_arrays):
        """xarray.DataArray inputs are converted to radians element-wise."""
        (lat1, lon1), (lat2, lon2) = deg2rad_all_points(*custom_xr_arrays)
        # ``array == pytest.approx(x)`` already reduces to one bool over all elements,
        # so compare plain numpy views rather than calling ``.all()`` on the result.
        assert np.asarray(lat1) == pytest.approx(0.0)
        assert np.asarray(lon1) == pytest.approx(np.pi / 6.)
        assert np.asarray(lat2) == pytest.approx(np.pi / 3.)
        assert np.asarray(lon2) == pytest.approx(np.pi / 2)

    def test_deg2rad_all_points_da_arr_inputs(self, custom_da_arrays):
        """dask array inputs are converted to radians; results are computed before comparing."""
        # Previously this method reused the xarray test's name and silently shadowed it.
        (lat1, lon1), (lat2, lon2) = deg2rad_all_points(*custom_da_arrays)
        assert lat1.compute() == pytest.approx(0.0)
        assert lon1.compute() == pytest.approx(np.pi / 6.)
        assert lat2.compute() == pytest.approx(np.pi / 3.)
        assert lon2.compute() == pytest.approx(np.pi / 2)


class TestHaversineDist:
    """Tests for ``haversine_dist`` on a unit sphere (``earth_radius=1.``).

    With a radius of 1 the expected distances are central angles in radians,
    e.g. pi for antipodal points.
    """

    @pytest.mark.parametrize("lat1,lon1,lat2,lon2,to_radians,expected_dist",
                             ((90., 0., -90., 0., True, np.pi),
                              (np.pi / 2., 0., np.pi / -2., 0., False, np.pi),
                              (0., 0., 0., 180., True, np.pi),
                              (0., 0., 0., np.pi, False, np.pi),
                              (0., 0., -45., 0, True, np.pi / 4.),
                              (0., 0., np.pi / -4., 0, False, np.pi / 4.),
                              ))
    def test_haversine_dist_on_unit_sphere_scalars(self, lat1, lon1, lat2, lon2, to_radians, expected_dist):
        """Scalar coordinates, given either in degrees or already in radians."""
        dist = haversine_dist(lat1=lat1, lon1=lon1, lat2=lat2, lon2=lon2, to_radians=to_radians, earth_radius=1.)
        assert dist == pytest.approx(expected_dist)

    @pytest.mark.parametrize("lat1,lon1,lat2,lon2,to_radians,expected_dist",
                             (
                              (xr.DataArray(np.array([90., 0, 45.]), dims='first', coords={'first': range(3)}),
                               xr.DataArray(np.array([0., 0, 0]), dims='first', coords={'first': range(3)}),
                               -90., 0., True,
                               np.array([[np.pi], [np.pi / 2.], [3./4.*np.pi]])),
                              (xr.DataArray(np.array([90., 0, 45.]), dims='first', coords={'first': range(3)}),
                               xr.DataArray(np.array([0., 0, 0]), dims='first', coords={'first': range(3)}),
                               np.array([-90., 0., 45.]),
                               np.array([0., 0., 0.]), True,
                               np.array([[np.pi, np.pi/2., np.pi/4.],
                                         [np.pi/2., 0., np.pi/4.],
                                         [3./4.*np.pi, np.pi/4., 0.]])),
                              ))
    def test_haversine_dist_on_unit_sphere_fields_and_scalars(self, lat1, lon1, lat2, lon2, to_radians, expected_dist):
        """DataArray first points against scalar or numpy-array second points.

        ``expected_dist`` is a 2-D array: rows follow the first points, columns the second points.
        """
        dist = haversine_dist(lat1=lat1, lon1=lon1, lat2=lat2, lon2=lon2, to_radians=to_radians, earth_radius=1.)
        # ``array == pytest.approx(arr)`` already returns a single bool (shape and values
        # must match), so compare a plain numpy view instead of calling ``.all()``.
        assert np.asarray(dist) == pytest.approx(expected_dist)

    @pytest.mark.parametrize("lat1,lon1,lat2,lon2,to_radians",
                             (
                                 (np.array([0., 0.]), 0., 0., 0., True),
                                 (0., np.array([0., 0.]), 0., 0., True),
                                 (0., 0., np.array([0., 0.]), 0., True),
                                 (0., 0., 0., np.array([0., 0.]), True),
                             ))
    def test_haversine_dist_on_unit_sphere_missmatch_dimensions(self, lat1, lon1, lat2, lon2, to_radians):
        """A latitude/longitude shape mismatch on either point must raise AssertionError."""
        with pytest.raises(AssertionError):
            haversine_dist(lat1=lat1, lon1=lon1, lat2=lat2, lon2=lon2, to_radians=to_radians, earth_radius=1.)