# NOTE: the star import is kept first so that the explicit imports below take
# precedence over any same-named objects re-exported by src.helpers.
from src.helpers import *

import logging
import os
import time

import keras
import mock
import numpy as np
import pytest


class TestToList:
    """Tests for ``to_list``: strings become a one-element list, lists compare equal."""

    def test_to_list(self):
        # (input, expected output) pairs, checked in order.
        cases = [
            ('a', ['a']),
            ('abcd', ['abcd']),
            ([1, 2, 3], [1, 2, 3]),
            ([45], [45]),
        ]
        for given, expected in cases:
            assert to_list(given) == expected


class TestCheckPath:
    """Tests for ``check_path_and_create`` and the log messages it emits."""

    def test_check_path_and_create(self, caplog):
        caplog.set_level(logging.INFO)
        path = 'data/test'
        # Precondition: the directory must not exist before the first call.
        assert not os.path.exists(path)
        try:
            check_path_and_create(path)
            assert os.path.exists(path)
            assert caplog.messages[0] == f"Created path: {path}"
            # A second call must detect the existing directory.
            check_path_and_create(path)
            assert caplog.messages[1] == f"Path already exists: {path}"
        finally:
            # Always clean up, even when an assertion fails; otherwise the
            # leftover directory breaks the precondition on every later run.
            if os.path.isdir(path):
                os.rmdir(path)


class TestLoss:
    """Tests for the ``l_p_loss`` loss factory."""

    def test_l_p_loss(self):
        # An identity Lambda layer has no weights, so fitting cannot change the
        # predictions. The expected values below correspond to
        # mean(|y_true - y_pred| ** p) over the four samples.
        model = keras.Sequential()
        model.add(keras.layers.Lambda(lambda x: x, input_shape=(None, )))
        model.compile(optimizer=keras.optimizers.Adam(), loss=l_p_loss(2))
        hist = model.fit(np.array([1, 0, 2, 0.5]), np.array([1, 1, 0, 0.5]), epochs=1)
        # The loss is a backend-computed float; compare with a tolerance, not ==.
        assert hist.history['loss'][0] == pytest.approx(1.25)
        model.compile(optimizer=keras.optimizers.Adam(), loss=l_p_loss(3))
        hist = model.fit(np.array([1, 0, -2, 0.5]), np.array([1, 1, 0, 0.5]), epochs=1)
        assert hist.history['loss'][0] == pytest.approx(2.25)


class TestLearningRateDecay:
    """Tests for the ``LearningRateDecay`` callback."""

    def test_init(self):
        # Default parameters and an empty learning-rate history.
        lr_decay = LearningRateDecay()
        assert lr_decay.lr == {'lr': []}
        assert lr_decay.base_lr == 0.01
        assert lr_decay.drop == 0.96
        assert lr_decay.epochs_drop == 8

    def test_check_param(self):
        # Bypass __init__: check_param is exercised in isolation.
        lr_decay = object.__new__(LearningRateDecay)
        assert lr_decay.check_param(1, "tester") == 1
        assert lr_decay.check_param(0.5, "tester") == 0.5
        with pytest.raises(ValueError) as e:
            lr_decay.check_param(0, "tester")
        assert "tester is out of allowed range (0, 1]: tester=0" in e.value.args[0]
        with pytest.raises(ValueError) as e:
            lr_decay.check_param(1.5, "tester")
        assert "tester is out of allowed range (0, 1]: tester=1.5" in e.value.args[0]
        # upper=None removes the upper bound.
        assert lr_decay.check_param(1.5, "tester", upper=None) == 1.5
        with pytest.raises(ValueError) as e:
            lr_decay.check_param(0, "tester", upper=None)
        assert "tester is out of allowed range (0, inf): tester=0" in e.value.args[0]
        # lower=None removes the lower bound.
        assert lr_decay.check_param(0.5, "tester", lower=None) == 0.5
        with pytest.raises(ValueError) as e:
            lr_decay.check_param(0.5, "tester", lower=None, upper=0.2)
        assert "tester is out of allowed range (-inf, 0.2]: tester=0.5" in e.value.args[0]
        # With both bounds removed the value must be returned unchanged, not
        # merely be truthy.
        assert lr_decay.check_param(10, "tester", upper=None, lower=None) == 10

    def test_on_epoch_begin(self):
        lr_decay = LearningRateDecay(base_lr=0.02, drop=0.95, epochs_drop=2)
        model = keras.Sequential()
        model.add(keras.layers.Dense(1, input_dim=1))
        model.compile(optimizer=keras.optimizers.Adam(), loss=l_p_loss(2))
        model.fit(np.array([1, 0, 2, 0.5]), np.array([1, 1, 0, 0.5]), epochs=5, callbacks=[lr_decay])
        # The rate drops by a factor of 0.95 every 2 epochs. The values are
        # floats produced by repeated multiplication, so compare approximately.
        expected = [0.02, 0.02, 0.02*0.95, 0.02*0.95, 0.02*0.95*0.95]
        assert lr_decay.lr['lr'] == pytest.approx(expected)


class TestTimeTracking:
    """Tests for the ``TimeTracking`` stopwatch helper."""

    def test_init(self):
        # By default the clock starts immediately; start=False defers it.
        t = TimeTracking()
        assert t.start is not None
        assert t.start < time.time()
        assert t.end is None
        t2 = TimeTracking(start=False)
        assert t2.start is None

    def test__start(self):
        t = TimeTracking(start=False)
        t._start()
        assert t.start < time.time()

    def test__end(self):
        t = TimeTracking()
        t._end()
        assert t.end > t.start

    def test__duration(self):
        # While running, the duration keeps growing.
        t = TimeTracking()
        d1 = t._duration()
        assert d1 > 0
        d2 = t._duration()
        assert d2 > d1
        # After _end() the duration is frozen.
        t._end()
        d3 = t._duration()
        assert d3 > d2
        assert d3 == t._duration()

    def test_repr(self):
        t = TimeTracking()
        t._end()
        duration = t._duration()
        assert repr(t).rstrip() == f"{round(duration, 2)}s".rstrip()

    def test_run(self):
        t = TimeTracking(start=False)
        assert t.start is None
        t.run()
        assert t.start is not None

    def test_stop(self):
        t = TimeTracking()
        assert t.end is None
        duration = t.stop()
        assert duration == t._duration()

    def test_duration(self):
        # The original check ("duration = t; assert duration is not None") was
        # a tautology; assert on the value actually returned by stop() instead.
        t = TimeTracking()
        duration = t.stop()
        assert duration is not None
        assert duration == t.duration()


class TestPrepareHost:
    """Tests for ``prepare_host``: data-path resolution by hostname and login."""

    # mock.patch decorators are applied bottom-up, so the mocks are passed in
    # reverse decorator order: os.path.exists, os.getlogin, socket.gethostname.
    @mock.patch("socket.gethostname", side_effect=["linux-gzsx", "ZAM144", "zam347", "jrtest", "jwtest"])
    @mock.patch("os.getlogin", return_value="testUser")
    @mock.patch("os.path.exists", return_value=True)
    def test_prepare_host(self, mock_path, mock_user, mock_host):
        # Each call consumes the next hostname from the side_effect list.
        path = prepare_host()
        assert path == "/home/testUser/machinelearningtools"
        path = prepare_host()
        assert path == "/home/testUser/Data/toar_daily/"
        path = prepare_host()
        assert path == "/home/testUser/Data/toar_daily/"
        path = prepare_host()
        assert path == "/p/project/cjjsc42/testUser/DATA/toar_daily/"
        path = prepare_host()
        assert path == "/p/home/jusers/testUser/juwels/intelliaq/DATA/toar_daily/"

    @mock.patch("socket.gethostname", return_value="NotExistingHostName")
    @mock.patch("os.getlogin", return_value="zombie21")
    @mock.patch("os.path.exists", return_value=False)
    def test_error_handling(self, mock_path, mock_user, mock_host):
        # Unknown host -> OSError.
        with pytest.raises(OSError) as e:
            prepare_host()
        assert "unknown host 'NotExistingHostName'" in e.value.args[0]
        # Known host whose data path is missing -> NotADirectoryError. The
        # os.path.exists mock keeps this independent of the real filesystem.
        mock_host.return_value = "linux-gzsx"
        with pytest.raises(NotADirectoryError) as e:
            prepare_host()
        assert "path '/home/zombie21/machinelearningtools' does not exist for host 'linux-gzsx'" in e.value.args[0]


class TestSetExperimentName:
    """Tests for ``set_experiment_name``."""

    def test_set_experiment(self):
        # Expected paths are resolved against the directory above this file.
        base_dir = os.path.join(os.path.dirname(__file__), "..")

        name, path = set_experiment_name()
        assert name == ""
        assert path == os.path.abspath(os.path.join(base_dir, ""))

        name, path = set_experiment_name(experiment_date="2019-11-14", experiment_path="./test2")
        assert name == "2019-11-14_network/"
        assert path == os.path.abspath(os.path.join(base_dir, "test2"))

    def test_set_experiment_from_sys(self):
        # NOTE(review): despite the name, the date is passed explicitly here
        # rather than read from sys.argv; confirm the intended coverage.
        name, _unused_path = set_experiment_name(experiment_date="2019-11-14")
        assert name == "2019-11-14_network/"