Skip to content
Snippets Groups Projects
Commit 1973d71e authored by Felix Kleinert's avatar Felix Kleinert
Browse files

add first tests for geofkts

parent 2dd6c146
No related branches found
No related tags found
1 merge request!231Draft: Resolve "Create WRF-Chem data handler"
Pipeline #60200 passed
......@@ -9,8 +9,10 @@ import xarray as xr
from mlair.helpers.helpers import convert2xrda
from typing import Union
xr_int_float = Union[xr.DataArray, xr.Dataset, np.ndarray, int, float]
def deg2rad_all_points(lat1, lat2, lon1, lon2):
def deg2rad_all_points(lat1, lon1, lat2, lon2):
"""
Converts coordinates provided in lat1, lon1, lat2, and lon2 from deg to rad. In fact this method just calls
dasks deg2rad method on all inputs and returns a tuple of tuples.
......@@ -30,8 +32,8 @@ def deg2rad_all_points(lat1, lat2, lon1, lon2):
return (lat1, lon1), (lat2, lon2)
def haversine_dist(lat1: xr.DataArray, lon1: xr.DataArray,
lat2: Union[np.ndarray, xr.DataArray], lon2: Union[np.ndarray, xr.DataArray],
def haversine_dist(lat1: xr_int_float, lon1: xr_int_float,
lat2: xr_int_float, lon2: xr_int_float,
to_radians: bool = True, earth_radius: float = 6371.,) -> xr.DataArray:
"""
Calculate the great circle distance between two points
......@@ -49,7 +51,7 @@ def haversine_dist(lat1: xr.DataArray, lon1: xr.DataArray,
"""
if to_radians:
(lat1, lon1), (lat2, lon2) = deg2rad_all_points(lat1, lat2, lon1, lon2)
(lat1, lon1), (lat2, lon2) = deg2rad_all_points(lat1, lon1, lat2, lon2)
lat1 = convert2xrda(lat1, use_1d_default=True)
lon1 = convert2xrda(lon1, use_1d_default=True)
......@@ -62,6 +64,7 @@ def haversine_dist(lat1: xr.DataArray, lon1: xr.DataArray,
assert isinstance(lon1, xr.DataArray)
assert isinstance(lat2, xr.DataArray)
assert isinstance(lon2, xr.DataArray)
assert len(lat1.shape) >= len(lat2.shape)
# broadcast lats and lons to calculate distances in a vectorized manner.
lat1, lat2 = xr.broadcast(lat1, lat2)
......
import pytest
import dask.array as da
import numpy as np
import xarray as xr
from mlair.helpers.geofunctions import deg2rad_all_points, haversine_dist
class TestDeg2RadAllPoints:
@pytest.fixture
def custom_np_arrays(self):
return np.array([0., 0.]), np.array([30., 30.]), np.array([60., 60.]), np.array([90., 90.])
@pytest.fixture
def custom_da_arrays(self, custom_np_arrays):
return (da.array(i) for i in custom_np_arrays)
@pytest.fixture
def custom_xr_arrays(self, custom_np_arrays):
return (xr.DataArray(i) for i in custom_np_arrays)
@pytest.mark.parametrize("value", ((0., 30., 60., 90.), (0, 30, 60, 90),
pytest.lazy_fixture('custom_np_arrays')))
def test_deg2rad_all_points_scalar_inputs(self, value):
(lat1, lon1), (lat2, lon2) = deg2rad_all_points(*value)
assert lat1 == pytest.approx(0.0)
assert lon1 == pytest.approx(np.pi/6.)
assert lat2 == pytest.approx(np.pi/3.)
assert lon2 == pytest.approx(np.pi/2)
def test_deg2rad_all_points_xr_arr_inputs(self, custom_xr_arrays):
(lat1, lon1), (lat2, lon2) = deg2rad_all_points(*custom_xr_arrays)
assert (lat1 == pytest.approx(0.0)).all()
assert (lon1 == pytest.approx(np.pi / 6.)).all()
assert (lat2 == pytest.approx(np.pi / 3.)).all()
assert (lon2 == pytest.approx(np.pi / 2)).all()
def test_deg2rad_all_points_xr_arr_inputs(self, custom_da_arrays):
(lat1, lon1), (lat2, lon2) = deg2rad_all_points(*custom_da_arrays)
assert lat1.compute() == pytest.approx(0.0)
assert lon1.compute() == pytest.approx(np.pi / 6.)
assert lat2.compute() == pytest.approx(np.pi / 3.)
assert lon2.compute() == pytest.approx(np.pi / 2)
class TestHaversineDist:
@pytest.mark.parametrize("lat1,lon1,lat2,lon2,to_radians,expected_dist",
((90., 0., -90., 0., True, np.pi),
(np.pi / 2., 0., np.pi / -2., 0., False, np.pi),
(0., 0., 0., 180., True, np.pi),
(0., 0., 0., np.pi, False, np.pi),
(0., 0., -45., 0, True, np.pi / 4.),
(0., 0., np.pi / -4., 0, False, np.pi / 4.),
))
def test_haversine_dist_on_unit_sphere_scalars(self, lat1, lon1, lat2, lon2, to_radians, expected_dist):
dist = haversine_dist(lat1=lat1, lon1=lon1, lat2=lat2, lon2=lon2, to_radians=to_radians, earth_radius=1.)
assert dist == pytest.approx(expected_dist)
@pytest.mark.parametrize("lat1,lon1,lat2,lon2,to_radians,expected_dist",
(
(xr.DataArray(np.array([90., 0, 45.]), dims='first', coords={'first': range(3)}),
xr.DataArray(np.array([0., 0, 0]), dims='first', coords={'first': range(3)}),
-90., 0., True,
np.array([[np.pi], [np.pi / 2.], [3./4.*np.pi]])),
(xr.DataArray(np.array([90., 0, 45.]), dims='first', coords={'first': range(3)}),
xr.DataArray(np.array([0., 0, 0]), dims='first', coords={'first': range(3)}),
np.array([-90., 0., 45.]),
np.array([0., 0., 0.]), True,
np.array([[np.pi, np.pi/2., np.pi/4.],
[np.pi/2., 0., np.pi/4.],
[3./4.*np.pi, np.pi/4., 0.]])
)
))
def test_haversine_dist_on_unit_sphere_fields_and_scalars(self, lat1, lon1, lat2, lon2, to_radians, expected_dist):
dist = haversine_dist(lat1=lat1, lon1=lon1, lat2=lat2, lon2=lon2, to_radians=to_radians, earth_radius=1.)
assert (dist == pytest.approx(expected_dist)).all()
@pytest.mark.parametrize("lat1,lon1,lat2,lon2,to_radians",
(
(np.array([0., 0.]), 0., 0., 0., True),
(0., np.array([0., 0.]), 0., 0., True),
(0., 0.,np.array([0., 0.]), 0., True),
(0., 0., 0., np.array([0., 0.]), True),
))
def test_haversine_dist_on_unit_sphere_missmatch_dimensions(self, lat1, lon1, lat2, lon2, to_radians):
with pytest.raises(AssertionError) as e:
dist = haversine_dist(lat1=lat1, lon1=lon1, lat2=lat2, lon2=lon2, to_radians=to_radians, earth_radius=1.)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment