"""Collection of different help functions."""
__author__ = 'Lukas Leufen, Felix Kleinert'
__date__ = '2019-10-21'

import inspect
import math

import xarray as xr

from typing import Dict, Callable, Union, List, Any


def to_list(obj: Any) -> List:
    """
    Wrap obj into a list unless it already is one.

    :param obj: arbitrary object; a list (or list subclass) is passed through untouched

    :return: obj itself when it is a list, otherwise a new single-element list holding obj
    """
    return obj if isinstance(obj, list) else [obj]


def dict_to_xarray(d: Dict, coordinate_name: str) -> xr.DataArray:
    """
    Convert a dictionary of 2D-xarrays to single 3D-xarray. The name of new coordinate axis follows <coordinate_name>.

    Every value of d is labelled with its key as coordinate <coordinate_name>, and all values are then concatenated
    along this new axis in dictionary iteration order. The arrays stored in d are not modified: the coordinate is
    attached to shallow copies, so the data itself is not duplicated before the final concatenation.

    :param d: dictionary with 2D-xarrays
    :param coordinate_name: name of the new created axis (2D -> 3D)

    :return: combined xarray. If d holds only a single entry, that entry is returned (as a shallow copy) with
        <coordinate_name> attached as scalar coordinate, without concatenation. If d is empty, None is returned.
    """
    labelled = []
    for key, value in d.items():
        # shallow copy: label the array without mutating the caller's object (the same object may even appear
        # under several keys) and without copying the underlying data
        tmp_xarray = value.copy(deep=False)
        tmp_xarray.coords[coordinate_name] = key
        labelled.append(tmp_xarray)
    if not labelled:
        return None
    if len(labelled) == 1:
        # a single entry is returned as is (scalar coordinate, no new axis), as before
        return labelled[0]
    # concatenate all entries at once instead of re-copying the growing result for every additional entry
    return xr.concat(labelled, coordinate_name)


def float_round(number: float, decimals: int = 0, round_type: Callable = math.ceil) -> float:
    """
    Perform given rounding operation on number with the precision of decimals.

    The number is scaled by 10**decimals, rounded with round_type and scaled back. Most decimal fractions are not
    exactly representable as floats, so the scaled value can land marginally above or below an integer (e.g.
    0.07 * 100 == 7.000000000000001). A plain ceil/floor would then jump to the next step (0.08 instead of 0.07).
    To avoid this, a finite scaled float whose relative deviation from the nearest integer is below 1e-12 is snapped
    to that integer before rounding. As a result, float_round(0.07, 2) == 0.07 and
    float_round(0.57, 2, math.floor) == 0.57.

    :param number: the number to round
    :param decimals: numbers of decimals of the rounding operations (default 0 -> round to next integer value)
    :param round_type: the actual rounding operation. Can be any callable function like math.ceil, math.floor or python
        built-in round operation.

    :return: rounded number with desired precision
    """
    multiplier = 10. ** decimals
    scaled = number * multiplier
    # only plain finite floats are snapped; anything else (e.g. arrays, inf, nan) is passed to round_type unchanged
    if isinstance(scaled, float) and math.isfinite(scaled):
        nearest = round(scaled)
        # tolerance far above float noise (~1e-16 relative) but far below any meaningful decimal difference
        if math.isclose(scaled, nearest, rel_tol=1e-12):
            scaled = float(nearest)
    return round_type(scaled) / multiplier


def remove_items(obj: Union[List, Dict], items: Any):
    """
    Remove item(s) from either list or dictionary.

    obj itself is never modified; a new list / dictionary is returned.

    :param obj: object to remove items from (either dictionary or list)
    :param items: elements to remove from obj. Can either be a list or single entry / key. An empty list removes
        nothing.

    :return: object without items. For a list, a single item removes only its first occurrence, whereas several items
        remove every occurrence of each of them. Items / keys that are not present in obj are ignored.
    :raises TypeError: if obj is neither a list nor a dictionary
    """

    def remove_from_list(list_obj, item_list):
        """Remove implementation for lists."""
        if len(item_list) > 1:
            return [e for e in list_obj if e not in item_list]
        list_obj = list_obj.copy()
        # guard: an empty item_list has nothing to remove (previously item_list[0] raised IndexError)
        if item_list:
            try:
                list_obj.remove(item_list[0])
            except ValueError:
                pass  # item is not in the list: nothing to remove
        return list_obj

    def remove_from_dict(dict_obj, key_list):
        """Remove implementation for dictionaries."""
        return {k: v for k, v in dict_obj.items() if k not in key_list}

    items = to_list(items)
    if isinstance(obj, list):
        return remove_from_list(obj, items)
    elif isinstance(obj, dict):
        return remove_from_dict(obj, items)
    else:
        raise TypeError(f"remove_items does not support type {type(obj)}.")