import logging
import os
import platform

import keras
import mock
import numpy as np
import pytest

import re

from src.helpers import *


class TestToList:

    def test_to_list(self):
        """to_list wraps scalar values in a list and passes lists through unchanged."""
        cases = [
            # strings are treated as a single element, never split into characters
            ('a', ['a']),
            ('abcd', ['abcd']),
            # lists are returned with identical content
            ([1, 2, 3], [1, 2, 3]),
            ([45], [45]),
        ]
        for value, expected in cases:
            assert to_list(value) == expected


class TestCheckPath:

    def test_check_path_and_create(self, caplog):
        """check_path_and_create creates a missing directory and logs on both first and repeated calls."""
        caplog.set_level(logging.DEBUG)
        path = 'data/test'
        assert not os.path.exists(path)
        try:
            check_path_and_create(path)
            assert os.path.exists(path)
            assert caplog.messages[0] == "Created path: data/test"
            check_path_and_create(path)
            assert caplog.messages[1] == "Path already exists: data/test"
        finally:
            # Always clean up, even when an assertion above fails. Otherwise the leftover directory
            # makes the precondition check fail on every subsequent run.
            if os.path.isdir(path):
                os.rmdir(path)


class TestLoss:

    def test_l_p_loss(self):
        """Check the loss reported by l_p_loss(p) on a model whose predictions equal its inputs.

        The expected values below are consistent with mean(|y_pred - y_true| ** p).
        """
        model = keras.Sequential()
        # Identity Lambda layer: the model output is exactly the input, so the loss can be computed by hand.
        model.add(keras.layers.Lambda(lambda x: x, input_shape=(None,)))
        model.compile(optimizer=keras.optimizers.Adam(), loss=l_p_loss(2))
        hist = model.fit(np.array([1, 0, 2, 0.5]), np.array([1, 1, 0, 0.5]), epochs=1)
        # |diff| ** 2 = [0, 1, 4, 0] -> mean 1.25; compare with a tolerance because keras computes in float32.
        assert hist.history['loss'][0] == pytest.approx(1.25)
        model.compile(optimizer=keras.optimizers.Adam(), loss=l_p_loss(3))
        hist = model.fit(np.array([1, 0, -2, 0.5]), np.array([1, 1, 0, 0.5]), epochs=1)
        # |diff| ** 3 = [0, 1, 8, 0] -> mean 2.25
        assert hist.history['loss'][0] == pytest.approx(2.25)


class TestTimeTracking:
    """Tests for the TimeTracking stopwatch helper.

    Comparisons between consecutive clock readings are non-strict: on platforms with a coarse
    ``time.time`` resolution, two readings taken back to back can be identical.
    """

    def test_init(self):
        t = TimeTracking()
        assert t.start is not None
        assert t.start <= time.time()
        assert t.end is None
        t2 = TimeTracking(start=False)
        assert t2.start is None

    def test__start(self):
        t = TimeTracking(start=False)
        t._start()
        assert t.start <= time.time()

    def test__end(self):
        t = TimeTracking()
        t._end()
        assert t.end >= t.start

    def test__duration(self):
        t = TimeTracking()
        d1 = t._duration()
        assert d1 >= 0
        # While running, the duration never decreases.
        d2 = t._duration()
        assert d2 >= d1
        t._end()
        d3 = t._duration()
        assert d3 >= d2
        # Once stopped, the duration is frozen.
        assert d3 == t._duration()

    def test_repr(self):
        t = TimeTracking()
        t._end()
        duration = t._duration()
        assert t.__repr__().rstrip() == f"{dt.timedelta(seconds=math.ceil(duration))} (hh:mm:ss)".rstrip()

    def test_run(self):
        t = TimeTracking(start=False)
        assert t.start is None
        t.run()
        assert t.start is not None

    def test_stop(self):
        t = TimeTracking()
        assert t.end is None
        duration = t.stop(get_duration=True)
        assert duration == t._duration()
        # Stopping an already stopped timer is rejected.
        with pytest.raises(AssertionError) as e:
            t.stop()
        assert "Time was already stopped" in e.value.args[0]
        # Restarting clears the end mark, and stop() without get_duration returns None.
        t.run()
        assert t.end is None
        assert t.stop() is None
        assert t.end is not None

    def test_duration(self):
        t = TimeTracking()
        # Previously this assigned the TimeTracking object itself, which made the check a tautology.
        duration = t.duration()
        assert duration is not None
        duration = t.stop(get_duration=True)
        assert duration == t.duration()

    def test_enter_exit(self, caplog):
        caplog.set_level(logging.INFO)
        with TimeTracking() as t:
            assert t.start is not None
            assert t.end is None
        expression = PyTestRegex(r"undefined job finished after \d+:\d+:\d+ \(hh:mm:ss\)")
        assert caplog.record_tuples[-1] == ('root', logging.INFO, expression)

    def test_name_enter_exit(self, caplog):
        caplog.set_level(logging.INFO)
        with TimeTracking(name="my job") as t:
            assert t.start is not None
            assert t.end is None
        expression = PyTestRegex(r"my job finished after \d+:\d+:\d+ \(hh:mm:ss\)")
        assert caplog.record_tuples[-1] == ('root', logging.INFO, expression)


class TestPrepareHost:
    """Tests for prepare_host.

    Note: stacked ``mock.patch`` decorators pass their mocks bottom-up, so the first mock argument
    belongs to the lowest decorator.
    """

    @mock.patch("socket.gethostname", side_effect=["linux-aa9b", "ZAM144", "zam347", "jrtest", "jwtest",
                                                   "runner-6HmDp9Qd-project-2411-concurrent"])
    @mock.patch("os.getlogin", return_value="testUser")
    @mock.patch("os.path.exists", return_value=True)
    def test_prepare_host(self, mock_path, mock_user, mock_host):
        # Each call consumes the next host name from the side_effect list above.
        assert prepare_host() == "/home/testUser/machinelearningtools/data/toar_daily/"
        assert prepare_host() == "/home/testUser/Data/toar_daily/"
        assert prepare_host() == "/home/testUser/Data/toar_daily/"
        assert prepare_host() == "/p/project/cjjsc42/testUser/DATA/toar_daily/"
        assert prepare_host() == "/p/home/jusers/testUser/juwels/intelliaq/DATA/toar_daily/"
        assert prepare_host() == '/home/testUser/machinelearningtools/data/toar_daily/'

    @mock.patch("socket.gethostname", return_value="NotExistingHostName")
    @mock.patch("os.getlogin", return_value="zombie21")
    def test_error_handling_unknown_host(self, mock_user, mock_host):
        with pytest.raises(OSError) as e:
            prepare_host()
        assert "unknown host 'NotExistingHostName'" in e.value.args[0]

    @mock.patch("socket.gethostname", return_value="linux-aa9b")
    @mock.patch("os.getlogin", return_value="zombie21")
    @mock.patch("src.helpers.check_path_and_create", side_effect=PermissionError)
    def test_error_handling(self, mock_cpath, mock_user, mock_host):
        # Pin a known host so the outcome does not depend on the machine running the tests. An unknown
        # host would raise a plain OSError (see test_error_handling_unknown_host) rather than NotADirectoryError.
        with pytest.raises(NotADirectoryError) as e:
            prepare_host()
        assert PyTestRegex(r"path '.*' does not exist for host 'linux-aa9b'\.") == e.value.args[0]
        with pytest.raises(NotADirectoryError) as e:
            prepare_host(False)
        assert PyTestRegex(r"path '.*' does not exist for host 'linux-aa9b'\.") == e.value.args[0]

    @mock.patch("socket.gethostname", side_effect=["linux-aa9b", "ZAM144", "zam347", "jrtest", "jwtest",
                                                   "runner-6HmDp9Qd-project-2411-concurrent"])
    @mock.patch("os.getlogin", side_effect=OSError)
    @mock.patch("os.path.exists", return_value=True)
    def test_os_error(self, mock_path, mock_user, mock_host):
        # When os.getlogin fails, the user part of the path falls back to "default".
        path = prepare_host()
        assert path == "/home/default/machinelearningtools/data/toar_daily/"
        path = prepare_host()
        assert path == "/home/default/Data/toar_daily/"
        path = prepare_host()
        assert path == "/home/default/Data/toar_daily/"
        path = prepare_host()
        assert path == "/p/project/cjjsc42/default/DATA/toar_daily/"
        path = prepare_host()
        assert path == "/p/home/jusers/default/juwels/intelliaq/DATA/toar_daily/"
        path = prepare_host()
        assert path == '/home/default/machinelearningtools/data/toar_daily/'

    @mock.patch("socket.gethostname", side_effect=["linux-aa9b"])
    @mock.patch("os.getlogin", return_value="testUser")
    @mock.patch("os.path.exists", return_value=False)
    @mock.patch("os.makedirs", side_effect=None)
    def test_os_path_exists(self, mock_makedirs, mock_path, mock_user, mock_host):
        path = prepare_host()
        assert path == "/home/testUser/machinelearningtools/data/toar_daily/"


class TestSetExperimentName:

    @staticmethod
    def _expected_path(*parts):
        """Return the absolute path of ``parts`` joined below the parent directory of this test module."""
        return os.path.abspath(os.path.join(os.path.dirname(__file__), "..", *parts))

    def test_set_experiment(self):
        # Without arguments a default name and path are produced.
        default_name, default_path = set_experiment_name()
        assert default_name == "TestExperiment"
        assert default_path == self._expected_path("TestExperiment")
        # With a date and an explicit base path, the name is derived from the date and nested below that path.
        dated_name, dated_path = set_experiment_name(experiment_date="2019-11-14", experiment_path="./test2")
        assert dated_name == "2019-11-14_network"
        assert dated_path == self._expected_path("test2", dated_name)

    def test_set_experiment_from_sys(self):
        name = set_experiment_name(experiment_date="2019-11-14")[0]
        assert name == "2019-11-14_network"

    def test_set_expperiment_hourly(self):
        hourly_name, hourly_path = set_experiment_name(sampling="hourly")
        assert hourly_name == "TestExperiment_hourly"
        assert hourly_path == self._expected_path("TestExperiment_hourly")


class TestSetBootstrapPath:

    @mock.patch("os.makedirs", side_effect=None)
    def test_bootstrap_path_is_none(self, mock_makedir):
        # Patch os.makedirs like the sibling test, so nothing is created on disk in case set_bootstrap_path
        # tries to create the derived path (TestDataPath/../bootstrap_daily, i.e. below the cwd).
        bootstrap_path = set_bootstrap_path(None, 'TestDataPath/', 'daily')
        assert bootstrap_path == 'TestDataPath/../bootstrap_daily'

    @mock.patch("os.makedirs", side_effect=None)
    def test_bootstap_path_is_given(self, mock_makedir):
        # An explicitly given bootstrap path is returned unchanged.
        bootstrap_path = set_bootstrap_path('Test/path/to/boots', None, None)
        assert bootstrap_path == 'Test/path/to/boots'


class TestPytestRegex:

    @pytest.fixture
    def regex(self):
        """Provide a PyTestRegex built from the plain pattern 'teststring'."""
        pattern = "teststring"
        return PyTestRegex(pattern)

    def test_pytest_regex_init(self, regex):
        # The source pattern is kept on the compiled regex object.
        compiled = regex._regex
        assert compiled.pattern == "teststring"

    def test_pytest_regex_eq(self, regex):
        # A string matching the pattern compares equal; a string that breaks the pattern does not.
        matching, broken = "teststringabcd", "teststgabcd"
        assert regex == matching
        assert regex != broken

    def test_pytest_regex_repr(self, regex):
        assert repr(regex) == "teststring"


class TestDictToXarray:

    def test_dict_to_xarray(self):
        """dict_to_xarray stacks equally shaped DataArrays along a new dimension named after the argument."""
        array1 = xr.DataArray(np.random.randn(2, 3), dims=('x', 'y'), coords={'x': [10, 20]})
        array2 = xr.DataArray(np.random.randn(2, 3), dims=('x', 'y'), coords={'x': [10, 20]})
        d = {"number1": array1, "number2": array2}
        res = dict_to_xarray(d, "merge_dim")
        assert isinstance(res, xr.DataArray)
        # The new dimension gets a coordinate; 'y' had none on the inputs, so only 'x' is carried over.
        assert sorted(res.coords) == ["merge_dim", "x"]
        # One entry per dictionary item, prepended to the (2, 3) input shape.
        assert res.shape == (2, 2, 3)


class TestFloatRound:

    def test_float_round_ceil(self):
        # With only a value given, the result is rounded up to the next integer.
        for value, expected in ((4.6, 5), (239.3992, 240)):
            assert float_round(value) == expected

    def test_float_round_decimals(self):
        # Rounding up also applies at the requested number of decimals.
        for value, decimals, expected in ((23.0091, 2, 23.01), (23.1091, 3, 23.11)):
            assert float_round(value, decimals) == expected

    def test_float_round_type(self):
        # A custom rounding function replaces the default one.
        cases = (
            (2, math.floor, 34.92),
            (0, math.floor, 34.),
            (2, round, 34.92),
            (0, round, 35.),
        )
        for decimals, rounding, expected in cases:
            assert float_round(34.9221, decimals, rounding) == expected

    def test_float_round_negative(self):
        value = -34.9221
        # floor moves negative values away from zero ...
        assert float_round(value, 2, math.floor) == -34.93
        assert float_round(value, 0, math.floor) == -35.
        # ... while the default rounding moves them towards zero.
        assert float_round(value, 2) == -34.92
        assert float_round(value, 0) == -34.


class TestDictPop:

    @pytest.fixture
    def custom_dict(self):
        """Provide a dictionary with mixed key types as input for every dict_pop test."""
        return {'a': 1, 'b': 2, 2: 'ab'}

    def test_dict_pop_single(self, custom_dict):
        # Absent keys, given inside a list or as a bare key, leave the content unchanged.
        for absent in ([4], '4'):
            assert dict_pop(custom_dict, absent) == custom_dict
        # A present key is removed whether it is given bare or inside a list.
        for present in ('b', ['b']):
            assert dict_pop(custom_dict, present) == {'a': 1, 2: 'ab'}

    def test_dict_pop_multiple(self, custom_dict):
        expectations = [
            ([4, 'mykey'], custom_dict),     # no key present
            (['a', 2], {'b': 2}),            # every key present
            ([2, '10'], {'a': 1, 'b': 2}),   # one key present, one absent
        ]
        for keys, expected in expectations:
            assert dict_pop(custom_dict, keys) == expected

    def test_dict_pop_missing_argument(self, custom_dict):
        cases = [
            ((), "dict_pop() missing 2 required positional arguments: 'dict_orig' and 'pop_keys'"),
            ((custom_dict,), "dict_pop() missing 1 required positional argument: 'pop_keys'"),
        ]
        for args, message in cases:
            with pytest.raises(TypeError) as excinfo:
                dict_pop(*args)
            assert message in excinfo.value.args[0]


class TestListPop:

    @pytest.fixture
    def custom_list(self):
        """Provide a list with mixed item types as input for every list_pop test."""
        return [1, 2, 3, 'a', 'bc']

    def test_list_pop_single(self, custom_list):
        assert list_pop(custom_list, 1) == [2, 3, 'a', 'bc']
        assert list_pop(custom_list, 'bc') == [1, 2, 3, 'a']
        # An item that is not in the list leaves the content unchanged.
        assert list_pop(custom_list, 5) == custom_list

    def test_list_pop_multiple(self, custom_list):
        expectations = [
            ([2, 'a'], [1, 3, 'bc']),        # all items present
            (['bc', 10], [1, 2, 3, 'a']),    # one present, one absent
            ([10, 'aa'], custom_list),       # no item present
        ]
        for items, expected in expectations:
            assert list_pop(custom_list, items) == expected

    def test_list_pop_missing_argument(self, custom_list):
        cases = [
            ((), "list_pop() missing 2 required positional arguments: 'list_full' and 'pop_items'"),
            ((custom_list,), "list_pop() missing 1 required positional argument: 'pop_items'"),
        ]
        for args, message in cases:
            with pytest.raises(TypeError) as excinfo:
                list_pop(*args)
            assert message in excinfo.value.args[0]


class TestLogger:
    """Tests for the Logger helper.

    Regex patterns are raw strings: sequences like ``\\d`` and ``\\.`` are invalid escapes in normal
    string literals and trigger DeprecationWarning/SyntaxWarning on newer Python versions.
    """

    @pytest.fixture
    def logger(self):
        return Logger()

    def test_init_default(self):
        log = Logger()
        assert log.formatter == "%(asctime)s - %(levelname)s: %(message)s  [%(filename)s:%(funcName)s:%(lineno)s]"
        assert log.log_file == Logger.setup_logging_path()

    def test_setup_logging_path_none(self):
        log_file = Logger.setup_logging_path(None)
        assert PyTestRegex(
            r".*machinelearningtools/src/\.{2}/logging/logging_\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2}\.log") == log_file

    @mock.patch("os.makedirs", side_effect=None)
    def test_setup_logging_path_given(self, mock_makedirs):
        path = "my/test/path"
        log_path = Logger.setup_logging_path(path)
        assert PyTestRegex(r"my/test/path/logging_\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2}\.log") == log_path

    def test_logger_console_level0(self, logger):
        consol = logger.logger_console(0)
        assert isinstance(consol, logging.StreamHandler)
        assert consol.level == 0
        # NOTE(review): this only checks that the format string is accepted by logging.Formatter,
        # not that the handler actually uses it.
        formatter = logging.Formatter(logger.formatter)
        assert isinstance(formatter, logging.Formatter)

    def test_logger_console_level1(self, logger):
        consol = logger.logger_console(1)
        assert isinstance(consol, logging.StreamHandler)
        assert consol.level == 1
        # NOTE(review): same caveat as in test_logger_console_level0.
        formatter = logging.Formatter(logger.formatter)
        assert isinstance(formatter, logging.Formatter)

    def test_logger_console_level_wrong_type(self, logger):
        with pytest.raises(TypeError) as e:
            logger.logger_console(1.5)
        assert "Level not an integer or a valid string: 1.5" == e.value.args[0]