From 9562b7526a359857e5cb7eb38ed982f661409c1d Mon Sep 17 00:00:00 2001 From: leufen1 <l.leufen@fz-juelich.de> Date: Thu, 15 Oct 2020 13:47:08 +0200 Subject: [PATCH] corrected helper behaviour if single station is given --- mlair/helpers/helpers.py | 25 +++++++++++++++---------- test/test_helpers/test_helpers.py | 14 +++++++++++--- 2 files changed, 26 insertions(+), 13 deletions(-) diff --git a/mlair/helpers/helpers.py b/mlair/helpers/helpers.py index b12d9028..3ecf1f62 100644 --- a/mlair/helpers/helpers.py +++ b/mlair/helpers/helpers.py @@ -32,16 +32,21 @@ def dict_to_xarray(d: Dict, coordinate_name: str) -> xr.DataArray: :return: combined xarray """ - xarray = None - for k, v in d.items(): - if xarray is None: - xarray = v - xarray.coords[coordinate_name] = k - else: - tmp_xarray = v - tmp_xarray.coords[coordinate_name] = k - xarray = xr.concat([xarray, tmp_xarray], coordinate_name) - return xarray + if len(d.keys()) == 1: + k = list(d.keys()) + xarray: xr.DataArray = d[k[0]] + return xarray.expand_dims(dim={coordinate_name: k}, axis=0) + else: + xarray = None + for k, v in d.items(): + if xarray is None: + xarray = v + xarray.coords[coordinate_name] = k + else: + tmp_xarray = v + tmp_xarray.coords[coordinate_name] = k + xarray = xr.concat([xarray, tmp_xarray], coordinate_name) + return xarray def float_round(number: float, decimals: int = 0, round_type: Callable = math.ceil) -> float: diff --git a/test/test_helpers/test_helpers.py b/test/test_helpers/test_helpers.py index 281d60e0..723b4a87 100644 --- a/test/test_helpers/test_helpers.py +++ b/test/test_helpers/test_helpers.py @@ -124,14 +124,22 @@ class TestPytestRegex: class TestDictToXarray: def test_dict_to_xarray(self): - array1 = xr.DataArray(np.random.randn(2, 3), dims=('x', 'y'), coords={'x': [10, 20]}) - array2 = xr.DataArray(np.random.randn(2, 3), dims=('x', 'y'), coords={'x': [10, 20]}) + array1 = xr.DataArray(np.random.randn(2, 3), dims=('x', 'y'), coords={'x': [10, 20], 'y': [0, 10, 20]}) + array2 = xr.DataArray(np.random.randn(2, 3), dims=('x', 'y'), coords={'x': [10, 20], 'y': [0, 10, 20]}) d = {"number1": array1, "number2": array2} res = dict_to_xarray(d, "merge_dim") assert type(res) == xr.DataArray - assert sorted(list(res.coords)) == ["merge_dim", "x"] + assert sorted(list(res.coords)) == ["merge_dim", "x", "y"] assert res.shape == (2, 2, 3) + def test_dict_to_xarray_single_entry(self): + array1 = xr.DataArray(np.random.randn(2, 3), dims=('x', 'y'), coords={'x': [10, 20], 'y': [0, 10, 20]}) + d = {"number1": array1} + res = dict_to_xarray(d, "merge_dim") + assert type(res) == xr.DataArray + assert sorted(list(res.coords)) == ["merge_dim", "x", "y"] + assert res.shape == (1, 2, 3) + class TestFloatRound: -- GitLab