From 9562b7526a359857e5cb7eb38ed982f661409c1d Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Thu, 15 Oct 2020 13:47:08 +0200
Subject: [PATCH] corrected helper behaviour if single station is given

---
 mlair/helpers/helpers.py          | 25 +++++++++++++++----------
 test/test_helpers/test_helpers.py | 14 +++++++++++---
 2 files changed, 26 insertions(+), 13 deletions(-)

diff --git a/mlair/helpers/helpers.py b/mlair/helpers/helpers.py
index b12d9028..3ecf1f62 100644
--- a/mlair/helpers/helpers.py
+++ b/mlair/helpers/helpers.py
@@ -32,16 +32,21 @@ def dict_to_xarray(d: Dict, coordinate_name: str) -> xr.DataArray:
 
     :return: combined xarray
     """
-    xarray = None
-    for k, v in d.items():
-        if xarray is None:
-            xarray = v
-            xarray.coords[coordinate_name] = k
-        else:
-            tmp_xarray = v
-            tmp_xarray.coords[coordinate_name] = k
-            xarray = xr.concat([xarray, tmp_xarray], coordinate_name)
-    return xarray
+    if len(d.keys()) == 1:
+        k = list(d.keys())
+        xarray: xr.DataArray = d[k[0]]
+        return xarray.expand_dims(dim={coordinate_name: k}, axis=0)
+    else:
+        xarray = None
+        for k, v in d.items():
+            if xarray is None:
+                xarray = v
+                xarray.coords[coordinate_name] = k
+            else:
+                tmp_xarray = v
+                tmp_xarray.coords[coordinate_name] = k
+                xarray = xr.concat([xarray, tmp_xarray], coordinate_name)
+        return xarray
 
 
 def float_round(number: float, decimals: int = 0, round_type: Callable = math.ceil) -> float:
diff --git a/test/test_helpers/test_helpers.py b/test/test_helpers/test_helpers.py
index 281d60e0..723b4a87 100644
--- a/test/test_helpers/test_helpers.py
+++ b/test/test_helpers/test_helpers.py
@@ -124,14 +124,22 @@ class TestPytestRegex:
 class TestDictToXarray:
 
     def test_dict_to_xarray(self):
-        array1 = xr.DataArray(np.random.randn(2, 3), dims=('x', 'y'), coords={'x': [10, 20]})
-        array2 = xr.DataArray(np.random.randn(2, 3), dims=('x', 'y'), coords={'x': [10, 20]})
+        array1 = xr.DataArray(np.random.randn(2, 3), dims=('x', 'y'), coords={'x': [10, 20], 'y': [0, 10, 20]})
+        array2 = xr.DataArray(np.random.randn(2, 3), dims=('x', 'y'), coords={'x': [10, 20], 'y': [0, 10, 20]})
         d = {"number1": array1, "number2": array2}
         res = dict_to_xarray(d, "merge_dim")
         assert type(res) == xr.DataArray
-        assert sorted(list(res.coords)) == ["merge_dim", "x"]
+        assert sorted(list(res.coords)) == ["merge_dim", "x", "y"]
         assert res.shape == (2, 2, 3)
 
+    def test_dict_to_xarray_single_entry(self):
+        array1 = xr.DataArray(np.random.randn(2, 3), dims=('x', 'y'), coords={'x': [10, 20], 'y': [0, 10, 20]})
+        d = {"number1": array1}
+        res = dict_to_xarray(d, "merge_dim")
+        assert type(res) == xr.DataArray
+        assert sorted(list(res.coords)) == ["merge_dim", "x", "y"]
+        assert res.shape == (1, 2, 3)
+
 
 class TestFloatRound:
 
-- 
GitLab