diff --git a/mlair/helpers/helpers.py b/mlair/helpers/helpers.py index d07d8cf1ca70ebdbd864cf58fb3b4a61ff183868..42b66dcb68b184112a321473e3aae250d697c452 100644 --- a/mlair/helpers/helpers.py +++ b/mlair/helpers/helpers.py @@ -18,7 +18,7 @@ def to_list(obj: Any) -> List: :return: list containing obj, or obj itself (if obj was already a list) """ - if isinstance(obj, set): + if isinstance(obj, (set, tuple)): obj = list(obj) elif not isinstance(obj, list): obj = [obj] @@ -116,6 +116,12 @@ def select_from_dict(dict_obj: dict, sel_list: Any): def extract_value(encapsulated_value): try: - return extract_value(encapsulated_value[0]) + if isinstance(encapsulated_value, str): + raise TypeError + if len(encapsulated_value) == 1: + return extract_value(encapsulated_value[0]) + else: + raise NotImplementedError("Trying to extract an encapsulated value from objects with more than a single " + "entry is not supported by this function.") except TypeError: return encapsulated_value diff --git a/test/test_helpers/test_helpers.py b/test/test_helpers/test_helpers.py index 723b4a87d70453327ed6b7e355d3ef78a246652a..9d56967559fcd84a21634b5e421fa9e91290b117 100644 --- a/test/test_helpers/test_helpers.py +++ b/test/test_helpers/test_helpers.py @@ -10,7 +10,7 @@ import os import mock import pytest -from mlair.helpers import to_list, dict_to_xarray, float_round, remove_items +from mlair.helpers import to_list, dict_to_xarray, float_round, remove_items, extract_value, select_from_dict from mlair.helpers import PyTestRegex from mlair.helpers import Logger, TimeTracking @@ -22,6 +22,9 @@ class TestToList: assert to_list('abcd') == ['abcd'] assert to_list([1, 2, 3]) == [1, 2, 3] assert to_list([45]) == [45] + assert to_list({34, 2, "test"}) == [34, 2, "test"] + assert to_list((34, 2, "test")) == [34, 2, "test"] + assert to_list(("test")) == ["test"] class TestTimeTracking: @@ -164,6 +167,22 @@ class TestFloatRound: assert float_round(-34.9221, 0) == -34. +class TestSelectFromDict: + + @pytest.fixture + def dictionary(self): + return {"a": 1, "b": 23, "c": "last"} + + def test_select(self, dictionary): + assert select_from_dict(dictionary, "c") == {"c": "last"} + assert select_from_dict(dictionary, ["a", "c"]) == {"a": 1, "c": "last"} + assert select_from_dict(dictionary, "d") == {} + + def test_select_no_dict_given(self): + with pytest.raises(AssertionError): + select_from_dict(["we"], "now") + + class TestRemoveItems: @pytest.fixture @@ -229,6 +248,11 @@ class TestRemoveItems: remove_items(custom_list) assert "remove_items() missing 1 required positional argument: 'items'" in e.value.args[0] + def test_remove_not_supported_type(self): + with pytest.raises(TypeError) as e: + remove_items(23, "test") + assert f"remove_items does not support type {type(23)}" in e.value.args[0] + class TestLogger: @@ -272,3 +296,18 @@ class TestLogger: with pytest.raises(TypeError) as e: logger.logger_console(1.5) assert "Level not an integer or a valid string: 1.5" == e.value.args[0] + + +class TestExtractValue: + + def test_extract(self): + assert extract_value([1]) == 1 + assert extract_value([[23]]) == 23 + assert extract_value([("test")]) == "test" + assert extract_value((2,)) == 2 + + def test_extract_multiple_elements(self): + with pytest.raises(NotImplementedError) as e: + extract_value([1, 2, 3]) + assert "Trying to extract an encapsulated value from objects with more than a single entry is not supported " \ + "by this function." in e.value.args[0]