diff --git a/mlair/helpers/geofunctions.py b/mlair/helpers/geofunctions.py index d9443c32e234c3b405577ccfd188d3b2be662861..4bf1fcfc5d53fcd716c01d8fa1f311434f7dcf78 100644 --- a/mlair/helpers/geofunctions.py +++ b/mlair/helpers/geofunctions.py @@ -9,8 +9,10 @@ import xarray as xr from mlair.helpers.helpers import convert2xrda from typing import Union +xr_int_float = Union[xr.DataArray, xr.Dataset, np.ndarray, int, float] -def deg2rad_all_points(lat1, lat2, lon1, lon2): + +def deg2rad_all_points(lat1, lon1, lat2, lon2): """ Converts coordinates provided in lat1, lon1, lat2, and lon2 from deg to rad. In fact this method just calls dasks deg2rad method on all inputs and returns a tuple of tuples. @@ -30,8 +32,8 @@ def deg2rad_all_points(lat1, lat2, lon1, lon2): return (lat1, lon1), (lat2, lon2) -def haversine_dist(lat1: xr.DataArray, lon1: xr.DataArray, - lat2: Union[np.ndarray, xr.DataArray], lon2: Union[np.ndarray, xr.DataArray], +def haversine_dist(lat1: xr_int_float, lon1: xr_int_float, + lat2: xr_int_float, lon2: xr_int_float, to_radians: bool = True, earth_radius: float = 6371.,) -> xr.DataArray: """ Calculate the great circle distance between two points @@ -49,7 +51,7 @@ def haversine_dist(lat1: xr.DataArray, lon1: xr.DataArray, """ if to_radians: - (lat1, lon1), (lat2, lon2) = deg2rad_all_points(lat1, lat2, lon1, lon2) + (lat1, lon1), (lat2, lon2) = deg2rad_all_points(lat1, lon1, lat2, lon2) lat1 = convert2xrda(lat1, use_1d_default=True) lon1 = convert2xrda(lon1, use_1d_default=True) @@ -62,6 +64,7 @@ def haversine_dist(lat1: xr.DataArray, lon1: xr.DataArray, assert isinstance(lon1, xr.DataArray) assert isinstance(lat2, xr.DataArray) assert isinstance(lon2, xr.DataArray) + assert len(lat1.shape) >= len(lat2.shape) # broadcast lats and lons to calculate distances in a vectorized manner. lat1, lat2 = xr.broadcast(lat1, lat2) diff --git a/test/test_helpers/test_geofunctions.py b/test/test_helpers/test_geofunctions.py new file mode 100644 index 0000000000000000000000000000000000000000..1c096e06ac2d091de2a973d5013181849933e743 --- /dev/null +++ b/test/test_helpers/test_geofunctions.py @@ -0,0 +1,90 @@ +import pytest +import dask.array as da +import numpy as np +import xarray as xr + +from mlair.helpers.geofunctions import deg2rad_all_points, haversine_dist + + +class TestDeg2RadAllPoints: + + @pytest.fixture + def custom_np_arrays(self): + return np.array([0., 0.]), np.array([30., 30.]), np.array([60., 60.]), np.array([90., 90.]) + + @pytest.fixture + def custom_da_arrays(self, custom_np_arrays): + return (da.array(i) for i in custom_np_arrays) + + @pytest.fixture + def custom_xr_arrays(self, custom_np_arrays): + return (xr.DataArray(i) for i in custom_np_arrays) + + @pytest.mark.parametrize("value", ((0., 30., 60., 90.), (0, 30, 60, 90), + pytest.lazy_fixture('custom_np_arrays'))) + def test_deg2rad_all_points_scalar_inputs(self, value): + (lat1, lon1), (lat2, lon2) = deg2rad_all_points(*value) + assert lat1 == pytest.approx(0.0) + assert lon1 == pytest.approx(np.pi/6.) + assert lat2 == pytest.approx(np.pi/3.) + assert lon2 == pytest.approx(np.pi/2) + + def test_deg2rad_all_points_xr_arr_inputs(self, custom_xr_arrays): + (lat1, lon1), (lat2, lon2) = deg2rad_all_points(*custom_xr_arrays) + assert (lat1 == pytest.approx(0.0)).all() + assert (lon1 == pytest.approx(np.pi / 6.)).all() + assert (lat2 == pytest.approx(np.pi / 3.)).all() + assert (lon2 == pytest.approx(np.pi / 2)).all() + + def test_deg2rad_all_points_xr_arr_inputs(self, custom_da_arrays): + (lat1, lon1), (lat2, lon2) = deg2rad_all_points(*custom_da_arrays) + assert lat1.compute() == pytest.approx(0.0) + assert lon1.compute() == pytest.approx(np.pi / 6.) + assert lat2.compute() == pytest.approx(np.pi / 3.) + assert lon2.compute() == pytest.approx(np.pi / 2) + + +class TestHaversineDist: + + @pytest.mark.parametrize("lat1,lon1,lat2,lon2,to_radians,expected_dist", + ((90., 0., -90., 0., True, np.pi), + (np.pi / 2., 0., np.pi / -2., 0., False, np.pi), + (0., 0., 0., 180., True, np.pi), + (0., 0., 0., np.pi, False, np.pi), + (0., 0., -45., 0, True, np.pi / 4.), + (0., 0., np.pi / -4., 0, False, np.pi / 4.), + )) + def test_haversine_dist_on_unit_sphere_scalars(self, lat1, lon1, lat2, lon2, to_radians, expected_dist): + dist = haversine_dist(lat1=lat1, lon1=lon1, lat2=lat2, lon2=lon2, to_radians=to_radians, earth_radius=1.) + assert dist == pytest.approx(expected_dist) + + @pytest.mark.parametrize("lat1,lon1,lat2,lon2,to_radians,expected_dist", + ( + (xr.DataArray(np.array([90., 0, 45.]), dims='first', coords={'first': range(3)}), + xr.DataArray(np.array([0., 0, 0]), dims='first', coords={'first': range(3)}), + -90., 0., True, + np.array([[np.pi], [np.pi / 2.], [3./4.*np.pi]])), + (xr.DataArray(np.array([90., 0, 45.]), dims='first', coords={'first': range(3)}), + xr.DataArray(np.array([0., 0, 0]), dims='first', coords={'first': range(3)}), + np.array([-90., 0., 45.]), + np.array([0., 0., 0.]), True, + np.array([[np.pi, np.pi/2., np.pi/4.], + [np.pi/2., 0., np.pi/4.], + [3./4.*np.pi, np.pi/4., 0.]]) + ) + )) + def test_haversine_dist_on_unit_sphere_fields_and_scalars(self, lat1, lon1, lat2, lon2, to_radians, expected_dist): + dist = haversine_dist(lat1=lat1, lon1=lon1, lat2=lat2, lon2=lon2, to_radians=to_radians, earth_radius=1.) + assert (dist == pytest.approx(expected_dist)).all() + + @pytest.mark.parametrize("lat1,lon1,lat2,lon2,to_radians", + ( + (np.array([0., 0.]), 0., 0., 0., True), + (0., np.array([0., 0.]), 0., 0., True), + (0., 0.,np.array([0., 0.]), 0., True), + (0., 0., 0., np.array([0., 0.]), True), + + )) + def test_haversine_dist_on_unit_sphere_missmatch_dimensions(self, lat1, lon1, lat2, lon2, to_radians): + with pytest.raises(AssertionError) as e: + dist = haversine_dist(lat1=lat1, lon1=lon1, lat2=lat2, lon2=lon2, to_radians=to_radians, earth_radius=1.)