import datetime as dt
import logging
import math
import os
import re
import string
import time

import dask.array as da
import mock
import numpy as np
import pytest
import xarray as xr

from mlair.helpers import to_list, dict_to_xarray, float_round, remove_items, extract_value, select_from_dict
from mlair.helpers import PyTestRegex
from mlair.helpers import Logger, TimeTracking
from mlair.helpers.helpers import is_xarray, convert2xrda


class TestToList:
    """Tests for :func:`to_list`."""

    def test_to_list(self):
        # Strings are wrapped as a single element, not split into characters.
        assert to_list('a') == ['a']
        assert to_list('abcd') == ['abcd']
        # Lists come back with the same content.
        assert to_list([1, 2, 3]) == [1, 2, 3]
        assert to_list([45]) == [45]
        # Sets and tuples are converted element-wise.
        s = {34, 2, "test"}
        assert to_list(s) == list(s)
        assert to_list((34, 2, "test")) == [34, 2, "test"]
        # The trailing comma is required: ("test") without it is just the string "test".
        assert to_list(("test",)) == ["test"]


class TestTimeTracking:
    """Tests for :class:`TimeTracking`, a start/stop timer that can also be used as a context manager."""

    def test_init(self):
        # By default the timer starts on construction; start=False defers starting it.
        t = TimeTracking()
        assert t.start is not None
        assert t.start < time.time()
        assert t.end is None
        t2 = TimeTracking(start=False)
        assert t2.start is None

    def test__start(self):
        t = TimeTracking(start=False)
        t._start()
        assert t.start < time.time()

    def test__end(self):
        t = TimeTracking()
        t._end()
        assert t.end > t.start

    def test__duration(self):
        # While running, the duration keeps growing; after _end() it is frozen.
        t = TimeTracking()
        d1 = t._duration()
        assert d1 > 0
        d2 = t._duration()
        assert d2 > d1
        t._end()
        d3 = t._duration()
        assert d3 > d2
        assert d3 == t._duration()

    def test_repr(self):
        # The repr shows the duration rounded up to whole seconds, formatted as a timedelta.
        t = TimeTracking()
        t._end()
        duration = t._duration()
        assert t.__repr__().rstrip() == f"{dt.timedelta(seconds=math.ceil(duration))} (hh:mm:ss)".rstrip()

    def test_run(self):
        t = TimeTracking(start=False)
        assert t.start is None
        t.run()
        assert t.start is not None

    def test_stop(self):
        t = TimeTracking()
        assert t.end is None
        duration = t.stop(get_duration=True)
        assert duration == t._duration()
        # Stopping an already stopped timer is an error ...
        with pytest.raises(AssertionError) as e:
            t.stop()
        assert "Time was already stopped" in e.value.args[0]
        # ... but run() resets the end time so the timer can be stopped again.
        t.run()
        assert t.end is None
        assert t.stop() is None
        assert t.end is not None

    def test_duration(self):
        t = TimeTracking()
        # duration() can be queried while the timer is still running ...
        duration = t.duration()
        assert duration is not None
        # ... and after stopping it matches the value returned by stop().
        duration = t.stop(get_duration=True)
        assert duration == t.duration()

    def test_enter_exit(self, caplog):
        # Leaving the context logs the elapsed time at INFO level on the root logger.
        caplog.set_level(logging.INFO)
        with TimeTracking() as t:
            assert t.start is not None
            assert t.end is None
        expression = PyTestRegex(r"undefined job finished after \d+:\d+:\d+ \(hh:mm:ss\)")
        assert caplog.record_tuples[-1] == ('root', logging.INFO, expression)

    def test_name_enter_exit(self, caplog):
        # A given name replaces "undefined job" in the log message.
        caplog.set_level(logging.INFO)
        with TimeTracking(name="my job") as t:
            assert t.start is not None
            assert t.end is None
        expression = PyTestRegex(r"my job finished after \d+:\d+:\d+ \(hh:mm:ss\)")
        assert caplog.record_tuples[-1] == ('root', logging.INFO, expression)


class TestPytestRegex:
    """Tests for ``PyTestRegex``, a helper that compares equal to strings matching its pattern."""

    @pytest.fixture
    def matcher(self):
        """A PyTestRegex built from the literal pattern ``teststring``."""
        return PyTestRegex("teststring")

    def test_pytest_regex_init(self, matcher):
        # The source pattern is kept on the compiled regex object.
        compiled = matcher._regex
        assert compiled.pattern == "teststring"

    def test_pytest_regex_eq(self, matcher):
        # "teststringabcd" matches the pattern, "teststgabcd" does not.
        matching, non_matching = "teststringabcd", "teststgabcd"
        assert matcher == matching
        assert matcher != non_matching

    def test_pytest_regex_repr(self, matcher):
        representation = matcher.__repr__()
        assert representation == "teststring"


class TestDictToXarray:
    """Tests for :func:`dict_to_xarray`, which combines a dict of DataArrays along a new leading dimension."""

    @staticmethod
    def _random_array():
        """Return a random 2x3 DataArray with integer coordinates on dims ``x`` and ``y``."""
        return xr.DataArray(np.random.randn(2, 3), dims=('x', 'y'), coords={'x': [10, 20], 'y': [0, 10, 20]})

    def test_dict_to_xarray(self):
        d = {"number1": self._random_array(), "number2": self._random_array()}
        res = dict_to_xarray(d, "merge_dim")
        assert isinstance(res, xr.DataArray)
        assert sorted(res.coords) == ["merge_dim", "x", "y"]
        # one slice along merge_dim per dict entry
        assert res.shape == (2, 2, 3)

    def test_dict_to_xarray_single_entry(self):
        d = {"number1": self._random_array()}
        res = dict_to_xarray(d, "merge_dim")
        assert isinstance(res, xr.DataArray)
        assert sorted(res.coords) == ["merge_dim", "x", "y"]
        # a single entry still gets the merge dimension (of length 1)
        assert res.shape == (1, 2, 3)


class TestFloatRound:
    """Tests for ``float_round(value, decimals, round_type)``, whose default rounding is upwards."""

    def test_float_round_ceil(self):
        # Without extra arguments the value is rounded up to a whole number.
        for value, expected in ((4.6, 5), (239.3992, 240)):
            assert float_round(value) == expected

    def test_float_round_decimals(self):
        cases = [(23.0091, 2, 23.01), (23.1091, 3, 23.11)]
        for value, decimals, expected in cases:
            assert float_round(value, decimals) == expected

    def test_float_round_type(self):
        # An explicit rounding function replaces the default.
        cases = [
            (2, math.floor, 34.92),
            (0, math.floor, 34.),
            (2, round, 34.92),
            (0, round, 35.),
        ]
        for decimals, rounding, expected in cases:
            assert float_round(34.9221, decimals, rounding) == expected

    def test_float_round_negative(self):
        value = -34.9221
        cases = [
            ((2, math.floor), -34.93),
            ((0, math.floor), -35.),
            ((2,), -34.92),
            ((0,), -34.),
        ]
        for extra_args, expected in cases:
            assert float_round(value, *extra_args) == expected


class TestSelectFromDict:
    """Tests for ``select_from_dict``, which returns the sub-dict for the requested key(s)."""

    @pytest.fixture
    def sample(self):
        return dict(a=1, b=23, c="last", e=None)

    def test_select(self, sample):
        expectations = (
            ("c", {"c": "last"}),
            (["a", "c"], {"a": 1, "c": "last"}),
            ("d", {}),  # a key that is not present yields nothing
        )
        for keys, expected in expectations:
            assert select_from_dict(sample, keys) == expected

    def test_select_no_dict_given(self):
        # A non-dict first argument is rejected with an AssertionError.
        with pytest.raises(AssertionError):
            select_from_dict(["we"], "now")

    def test_select_remove_none(self, sample):
        assert select_from_dict(sample, ["a", "e"]) == {"a": 1, "e": None}
        # remove_none=True drops entries whose value is None.
        assert select_from_dict(sample, ["a", "e"], remove_none=True) == {"a": 1}


class TestRemoveItems:
    """Tests for ``remove_items``, which returns ``obj`` without the given item(s) or key(s)."""

    @pytest.fixture
    def custom_list(self):
        return [1, 2, 3, 'a', 'bc']

    @pytest.fixture
    def custom_dict(self):
        return {'a': 1, 'b': 2, 2: 'ab'}

    def test_dict_remove_single(self, custom_dict):
        # A key that is not present leaves the dict unchanged, whether given in a list or alone.
        for absent in ([4], '4'):
            assert remove_items(custom_dict, absent) == custom_dict
        # A present key is removed, whether given alone or in a list.
        for present in ('b', ['b']):
            assert remove_items(custom_dict, present) == {'a': 1, 2: 'ab'}

    def test_dict_remove_multiple(self, custom_dict):
        cases = (
            ([4, 'mykey'], custom_dict),      # none of the keys present
            (['a', 2], {'b': 2}),             # all keys present
            ([2, '10'], {'a': 1, 'b': 2}),    # one present, one not
        )
        for keys, expected in cases:
            assert remove_items(custom_dict, keys) == expected

    def test_list_remove_single(self, custom_list):
        cases = (
            (1, [2, 3, 'a', 'bc']),
            ('bc', [1, 2, 3, 'a']),
            (5, custom_list),                 # missing item: list unchanged
        )
        for item, expected in cases:
            assert remove_items(custom_list, item) == expected

    def test_list_remove_multiple(self, custom_list):
        cases = (
            ([2, 'a'], [1, 3, 'bc']),         # all items present
            (['bc', 10], [1, 2, 3, 'a']),     # one present, one not
            ([10, 'aa'], custom_list),        # none present
        )
        for items, expected in cases:
            assert remove_items(custom_list, items) == expected

    def test_remove_missing_argument(self, custom_dict, custom_list):
        with pytest.raises(TypeError) as e:
            remove_items()
        assert "remove_items() missing 2 required positional arguments: 'obj' and 'items'" in e.value.args[0]
        one_missing = "remove_items() missing 1 required positional argument: 'items'"
        for obj in (custom_dict, custom_list):
            with pytest.raises(TypeError) as e:
                remove_items(obj)
            assert one_missing in e.value.args[0]

    def test_remove_not_supported_type(self):
        unsupported = 23
        with pytest.raises(TypeError) as e:
            remove_items(unsupported, "test")
        assert f"remove_items does not support type {type(unsupported)}" in e.value.args[0]


class TestLogger:
    """Tests for :class:`Logger`: default format, log-file path setup and console handler creation."""

    @pytest.fixture
    def logger(self):
        return Logger()

    @staticmethod
    def _check_console_handler(logger, level):
        """Assert that ``logger.logger_console(level)`` returns a correctly configured StreamHandler."""
        console = logger.logger_console(level)
        assert isinstance(console, logging.StreamHandler)
        assert console.level == level
        # The handler itself must carry a formatter built from the Logger's format string.
        assert isinstance(console.formatter, logging.Formatter)
        assert console.formatter._fmt == logger.formatter

    def test_init_default(self):
        log = Logger()
        assert log.formatter == "%(asctime)s - %(levelname)s: %(message)s  [%(filename)s:%(funcName)s:%(lineno)s]"
        assert log.log_file == Logger.setup_logging_path()

    def test_setup_logging_path_none(self):
        log_file = Logger.setup_logging_path(None)
        # The cwd is a literal path and may contain regex metacharacters ('.', '+', '\\'), so escape it.
        test_regex = re.escape(os.getcwd()) + r"/logging/logging_\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2}\.log"
        assert PyTestRegex(test_regex) == log_file

    @mock.patch("os.makedirs", side_effect=None)
    def test_setup_logging_path_given(self, mock_makedirs):
        path = "my/test/path"
        log_path = Logger.setup_logging_path(path)
        # Raw string so '\d' and '\.' reach the regex engine instead of being invalid string escapes.
        assert PyTestRegex(r"my/test/path/logging_\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2}\.log") == log_path

    def test_logger_console_level0(self, logger):
        self._check_console_handler(logger, 0)

    def test_logger_console_level1(self, logger):
        self._check_console_handler(logger, 1)

    def test_logger_console_level_wrong_type(self, logger):
        with pytest.raises(TypeError) as e:
            logger.logger_console(1.5)
        assert "Level not an integer or a valid string: 1.5" == e.value.args[0]


class TestExtractValue:
    """Tests for :func:`extract_value`, which unwraps a value nested in single-element containers."""

    def test_extract(self):
        assert extract_value([1]) == 1
        assert extract_value([[23]]) == 23
        assert extract_value(["test"]) == "test"
        # The trailing comma makes this a real tuple nested in a list; ("test") alone is just a string.
        assert extract_value([("test",)]) == "test"
        assert extract_value((2,)) == 2

    def test_extract_multiple_elements(self):
        # Containers with more than one entry are rejected.
        with pytest.raises(NotImplementedError) as e:
            extract_value([1, 2, 3])
        assert "Trying to extract an encapsulated value from objects with more than a single entry is not supported " \
               "by this function." in e.value.args[0]


class TestIsXarray:
    """Tests for ``is_xarray``: True for DataArray/Dataset, False for other inputs."""

    @pytest.fixture
    def custom_xr_data(self):
        return xr.DataArray(np.array(range(5)))

    def test_is_xarray_xr_input(self, custom_xr_data):
        dataset = xr.Dataset({'test': custom_xr_data})
        for xr_obj in (custom_xr_data, dataset):
            assert is_xarray(xr_obj) is True

    def test_is_xarray_other_input(self, custom_xr_data):
        # Scalars and lists are rejected, even a list that wraps a DataArray.
        for other in (1, 1., [1, 2.], [custom_xr_data]):
            assert is_xarray(other) is False


class TestConvert2xrDa:
    """Tests for :func:`convert2xrda`.

    xarray input is expected back unchanged. numpy arrays and int/float scalars are wrapped into an
    ``xr.DataArray``. With ``use_1d_default=True`` a single ``points`` dimension is used unless
    ``dims``/``coords`` are passed as keyword arguments.
    """

    @pytest.fixture
    def custom_1d_npdata(self):
        return np.array(range(9))

    @pytest.fixture
    def custom_2d_npdata(self, custom_1d_npdata):
        # shape (2, 9)
        return np.stack([custom_1d_npdata, 2 * custom_1d_npdata])

    @pytest.fixture
    def custom_xr_dataarray(self, custom_1d_npdata):
        return xr.DataArray(custom_1d_npdata)

    @pytest.fixture
    def custom_xr_dataset(self, custom_xr_dataarray):
        return xr.Dataset({'test_1': custom_xr_dataarray})

    @pytest.fixture
    def custom_1d_daarray(self, custom_1d_npdata):
        return da.array(custom_1d_npdata)

    def test_convert2xrda_xrdata_in(self, custom_xr_dataarray, custom_xr_dataset):
        # Use xarray's comparison helpers: for a Dataset, ``(a == b).all()`` is itself a Dataset whose
        # truth value only reflects whether it has data variables, so it cannot detect a mismatch.
        xr.testing.assert_equal(convert2xrda(custom_xr_dataarray), custom_xr_dataarray)
        xr.testing.assert_equal(convert2xrda(custom_xr_dataset), custom_xr_dataset)

    def test_convert2xrda_npdata_in_nokwargs(self, custom_1d_npdata, custom_2d_npdata):
        # Without kwargs the default dimension names dim_0, dim_1, ... are used.
        converted_data = convert2xrda(custom_1d_npdata)
        assert isinstance(converted_data, xr.DataArray)
        np.testing.assert_array_equal(converted_data.values, custom_1d_npdata)
        assert converted_data.dims == ('dim_0',)
        assert converted_data.dim_0.size == custom_1d_npdata.shape[0]

        # Feed in a 2D-np.array without additional kwargs
        converted_data = convert2xrda(custom_2d_npdata)
        assert isinstance(converted_data, xr.DataArray)
        np.testing.assert_array_equal(converted_data.values, custom_2d_npdata)
        assert converted_data.dims == ('dim_0', 'dim_1')
        assert converted_data.dim_0.size == custom_2d_npdata.shape[0]
        assert converted_data.dim_1.size == custom_2d_npdata.shape[1]

    def test_convert2xrda_npdata_in_nokwargs_default_true(self, custom_1d_npdata, custom_2d_npdata):
        converted_data = convert2xrda(custom_1d_npdata, use_1d_default=True)
        assert isinstance(converted_data, xr.DataArray)
        np.testing.assert_array_equal(converted_data.values, custom_1d_npdata)
        assert converted_data.dims == ('points',)
        assert converted_data.points.size == custom_1d_npdata.shape[0]

        # The single default dimension cannot describe 2D data, so this is rejected with a ValueError.
        with pytest.raises(ValueError) as e:
            convert2xrda(custom_2d_npdata, use_1d_default=True)
        assert "different number of dimensions on data and dims: 2 vs 1" in e.value.args[0]

    @pytest.mark.parametrize("use_1d_default", (False, True))
    def test_convert2xrda_npdata_in_kwargs(self, custom_1d_npdata, custom_2d_npdata, use_1d_default):
        # Explicit dims/coords kwargs take effect regardless of use_1d_default.
        converted_data = convert2xrda(custom_1d_npdata, use_1d_default=use_1d_default, dims='other_points')
        assert isinstance(converted_data, xr.DataArray)
        np.testing.assert_array_equal(converted_data.values, custom_1d_npdata)
        assert converted_data.dims == ('other_points',)
        assert converted_data.other_points.size == custom_1d_npdata.shape[0]

        # Feed in a 2D-np.array with correct additional kwargs
        coords_0 = list(string.ascii_lowercase[:custom_2d_npdata.shape[0]])
        coords_1 = list(string.ascii_lowercase[:custom_2d_npdata.shape[1]])
        converted_data = convert2xrda(custom_2d_npdata, use_1d_default=use_1d_default,
                                      dims=['test_dim_0', 'test_dim_1'],
                                      coords={'test_dim_0': coords_0,
                                              'test_dim_1': coords_1,
                                              },
                                      )
        assert isinstance(converted_data, xr.DataArray)
        np.testing.assert_array_equal(converted_data.values, custom_2d_npdata)
        assert converted_data.dims == ('test_dim_0', 'test_dim_1')
        np.testing.assert_array_equal(converted_data.coords['test_dim_0'].values, ['a', 'b'])
        np.testing.assert_array_equal(converted_data.coords['test_dim_1'].values, coords_1)

    @pytest.mark.parametrize("scalar", (1, 2.))
    def test_convert2xrda_int_float_in_nokwargs_default_true(self, scalar):
        converted_data = convert2xrda(scalar, use_1d_default=True)
        assert isinstance(converted_data, xr.DataArray)
        np.testing.assert_array_equal(converted_data.values, np.array([scalar]))
        assert converted_data.dims == ('points',)

    @pytest.mark.parametrize("wrong_input", ({1: 'b'}, [1], 'abc'))
    def test_convert2xrda_wrong_type_in_default_true_nokwargs(self, wrong_input):
        with pytest.raises(TypeError) as e:
            convert2xrda(wrong_input, use_1d_default=True)
        # Must match the library's message verbatim (including its "arry-like" spelling).
        assert f"`arr' must be arry-like, int or float. But is of type {type(wrong_input)}" in e.value.args[0]

    def test_convert2xrda_dask_in_default_true_nokwargs(self, custom_1d_daarray):
        with pytest.raises(TypeError) as e:
            convert2xrda(custom_1d_daarray, use_1d_default=True)
        assert ("`use_1d_default=True' is used with `arr' of type da.array. For da.arrays please pass"
                " `use_1d_default=False' and specify keywords for xr.DataArray via kwargs.") in e.value.args[0]