diff --git a/mlair/helpers/helpers.py b/mlair/helpers/helpers.py
index d07d8cf1ca70ebdbd864cf58fb3b4a61ff183868..42b66dcb68b184112a321473e3aae250d697c452 100644
--- a/mlair/helpers/helpers.py
+++ b/mlair/helpers/helpers.py
@@ -18,7 +18,7 @@ def to_list(obj: Any) -> List:
:return: list containing obj, or obj itself (if obj was already a list)
"""
- if isinstance(obj, set):
+ if isinstance(obj, (set, tuple)):
obj = list(obj)
elif not isinstance(obj, list):
obj = [obj]
@@ -116,6 +116,12 @@ def select_from_dict(dict_obj: dict, sel_list: Any):
def extract_value(encapsulated_value):
try:
- return extract_value(encapsulated_value[0])
+ if isinstance(encapsulated_value, str):
+ raise TypeError
+ if len(encapsulated_value) == 1:
+ return extract_value(encapsulated_value[0])
+ else:
+ raise NotImplementedError("Trying to extract an encapsulated value from objects with more than a single "
+ "entry is not supported by this function.")
except TypeError:
return encapsulated_value
diff --git a/test/test_helpers/test_helpers.py b/test/test_helpers/test_helpers.py
index 723b4a87d70453327ed6b7e355d3ef78a246652a..9d56967559fcd84a21634b5e421fa9e91290b117 100644
--- a/test/test_helpers/test_helpers.py
+++ b/test/test_helpers/test_helpers.py
@@ -10,7 +10,7 @@ import os
import mock
import pytest
-from mlair.helpers import to_list, dict_to_xarray, float_round, remove_items
+from mlair.helpers import to_list, dict_to_xarray, float_round, remove_items, extract_value, select_from_dict
from mlair.helpers import PyTestRegex
from mlair.helpers import Logger, TimeTracking
@@ -22,6 +22,9 @@ class TestToList:
assert to_list('abcd') == ['abcd']
assert to_list([1, 2, 3]) == [1, 2, 3]
assert to_list([45]) == [45]
+ assert to_list({34, 2, "test"}) == [34, 2, "test"]
+ assert to_list((34, 2, "test")) == [34, 2, "test"]
+ assert to_list(("test")) == ["test"]
class TestTimeTracking:
@@ -164,6 +167,22 @@ class TestFloatRound:
assert float_round(-34.9221, 0) == -34.
+class TestSelectFromDict:
+
+ @pytest.fixture
+ def dictionary(self):
+ return {"a": 1, "b": 23, "c": "last"}
+
+ def test_select(self, dictionary):
+ assert select_from_dict(dictionary, "c") == {"c": "last"}
+ assert select_from_dict(dictionary, ["a", "c"]) == {"a": 1, "c": "last"}
+ assert select_from_dict(dictionary, "d") == {}
+
+ def test_select_no_dict_given(self):
+ with pytest.raises(AssertionError):
+ select_from_dict(["we"], "now")
+
+
class TestRemoveItems:
@pytest.fixture
@@ -229,6 +248,11 @@ class TestRemoveItems:
remove_items(custom_list)
assert "remove_items() missing 1 required positional argument: 'items'" in e.value.args[0]
+ def test_remove_not_supported_type(self):
+ with pytest.raises(TypeError) as e:
+ remove_items(23, "test")
+ assert f"remove_items does not support type {type(23)}" in e.value.args[0]
+
class TestLogger:
@@ -272,3 +296,18 @@ class TestLogger:
with pytest.raises(TypeError) as e:
logger.logger_console(1.5)
assert "Level not an integer or a valid string: 1.5" == e.value.args[0]
+
+
+class TestExtractValue:
+
+ def test_extract(self):
+ assert extract_value([1]) == 1
+ assert extract_value([[23]]) == 23
+ assert extract_value([("test")]) == "test"
+ assert extract_value((2,)) == 2
+
+ def test_extract_multiple_elements(self):
+ with pytest.raises(NotImplementedError) as e:
+ extract_value([1, 2, 3])
+ assert "Trying to extract an encapsulated value from objects with more than a single entry is not supported " \
+ "by this function." in e.value.args[0]