diff --git a/mlair/helpers/helpers.py b/mlair/helpers/helpers.py index b12d9028747aa677802c4a99e35852b514128e4c..3ecf1f6213bf39d2e3571a1b451173b981a3dadf 100644 --- a/mlair/helpers/helpers.py +++ b/mlair/helpers/helpers.py @@ -32,16 +32,21 @@ def dict_to_xarray(d: Dict, coordinate_name: str) -> xr.DataArray: :return: combined xarray """ - xarray = None - for k, v in d.items(): - if xarray is None: - xarray = v - xarray.coords[coordinate_name] = k - else: - tmp_xarray = v - tmp_xarray.coords[coordinate_name] = k - xarray = xr.concat([xarray, tmp_xarray], coordinate_name) - return xarray + if len(d.keys()) == 1: + k = list(d.keys()) + xarray: xr.DataArray = d[k[0]] + return xarray.expand_dims(dim={coordinate_name: k}, axis=0) + else: + xarray = None + for k, v in d.items(): + if xarray is None: + xarray = v + xarray.coords[coordinate_name] = k + else: + tmp_xarray = v + tmp_xarray.coords[coordinate_name] = k + xarray = xr.concat([xarray, tmp_xarray], coordinate_name) + return xarray def float_round(number: float, decimals: int = 0, round_type: Callable = math.ceil) -> float: diff --git a/test/test_helpers/test_helpers.py b/test/test_helpers/test_helpers.py index 281d60e07463c6b5118f36714d80144443a03050..723b4a87d70453327ed6b7e355d3ef78a246652a 100644 --- a/test/test_helpers/test_helpers.py +++ b/test/test_helpers/test_helpers.py @@ -124,14 +124,22 @@ class TestPytestRegex: class TestDictToXarray: def test_dict_to_xarray(self): - array1 = xr.DataArray(np.random.randn(2, 3), dims=('x', 'y'), coords={'x': [10, 20]}) - array2 = xr.DataArray(np.random.randn(2, 3), dims=('x', 'y'), coords={'x': [10, 20]}) + array1 = xr.DataArray(np.random.randn(2, 3), dims=('x', 'y'), coords={'x': [10, 20], 'y': [0, 10, 20]}) + array2 = xr.DataArray(np.random.randn(2, 3), dims=('x', 'y'), coords={'x': [10, 20], 'y': [0, 10, 20]}) d = {"number1": array1, "number2": array2} res = dict_to_xarray(d, "merge_dim") assert type(res) == xr.DataArray - assert sorted(list(res.coords)) == ["merge_dim", "x"] + assert sorted(list(res.coords)) == ["merge_dim", "x", "y"] assert res.shape == (2, 2, 3) + def test_dict_to_xarray_single_entry(self): + array1 = xr.DataArray(np.random.randn(2, 3), dims=('x', 'y'), coords={'x': [10, 20], 'y': [0, 10, 20]}) + d = {"number1": array1} + res = dict_to_xarray(d, "merge_dim") + assert type(res) == xr.DataArray + assert sorted(list(res.coords)) == ["merge_dim", "x", "y"] + assert res.shape == (1, 2, 3) + class TestFloatRound: