diff --git a/mlair/helpers/helpers.py b/mlair/helpers/helpers.py index 36470ebc1c3a008c0f6ecca11478d83d6fa57cec..d07d8cf1ca70ebdbd864cf58fb3b4a61ff183868 100644 --- a/mlair/helpers/helpers.py +++ b/mlair/helpers/helpers.py @@ -12,13 +12,15 @@ from typing import Dict, Callable, Union, List, Any def to_list(obj: Any) -> List: """ - Transform given object to list if obj is not already a list. + Transform given object to list if obj is not already a list. Sets are also transformed to a list. :param obj: object to transform to list :return: list containing obj, or obj itself (if obj was already a list) """ - if not isinstance(obj, list): + if isinstance(obj, set): + obj = list(obj) + elif not isinstance(obj, list): obj = [obj] return obj @@ -99,7 +101,7 @@ def remove_items(obj: Union[List, Dict], items: Any): raise TypeError(f"{inspect.stack()[0][3]} does not support type {type(obj)}.") -def select_from_dict(dict_obj: dict, sel_list: str): +def select_from_dict(dict_obj: dict, sel_list: Any): """ Extract all key values pairs whose key is contained in the sel_list.