import pytest

from collections import OrderedDict
import os
import shutil

from matplotlib import pyplot as plt
import numpy as np

from src.plotting.tracker_plot import TrackObject, TrackChain, TrackPlot
from src.helpers import PyTestAllEqual


class TestTrackObject:
    """Unit tests for ``TrackObject``, a single node of the tracking graph."""

    @pytest.fixture
    def track_obj(self):
        # The name is passed as a plain string; test_init checks that it ends up wrapped in a list.
        return TrackObject("custom_name", "your_stage")

    def test_init(self, track_obj):
        assert track_obj.name == ["custom_name"]
        assert track_obj.stage == "your_stage"
        unset_attributes = ("precursor", "successor", "x", "y")
        assert all(getattr(track_obj, attribute) is None for attribute in unset_attributes)

    def test_repr(self, track_obj):
        track_obj.name = ["custom", "name"]
        # the name parts are joined with "/" in the representation
        assert repr(track_obj) == "custom/name"

    def test_x_property(self, track_obj):
        assert track_obj.x is None
        track_obj.x = 23
        assert track_obj.x == 23

    def test_y_property(self, track_obj):
        assert track_obj.y is None
        track_obj.y = 21
        assert track_obj.y == 21

    def test_add_precursor(self, track_obj):
        assert track_obj.precursor is None
        other = TrackObject(["another", "track"], "your_stage")

        track_obj.add_precursor(other)
        assert isinstance(track_obj.precursor, list)
        assert track_obj.precursor[-1] == other
        assert len(track_obj.precursor) == 1
        # the link is registered on the other object as well
        assert other.successor is not None

        # adding the same object again must not create a duplicate entry
        track_obj.add_precursor(other)
        assert len(track_obj.precursor) == 1

        track_obj.add_precursor(TrackObject(["third", "track"], "your_stage"))
        assert len(track_obj.precursor) == 2

    def test_add_successor(self, track_obj):
        assert track_obj.successor is None
        other = TrackObject(["another", "track"], "your_stage")

        track_obj.add_successor(other)
        assert isinstance(track_obj.successor, list)
        assert track_obj.successor[-1] == other
        assert len(track_obj.successor) == 1
        # the link is registered on the other object as well
        assert other.precursor is not None

        # adding the same object again must not create a duplicate entry
        track_obj.add_successor(other)
        assert len(track_obj.successor) == 1

        track_obj.add_successor(TrackObject(["third", "track"], "your_stage"))
        assert len(track_obj.successor) == 2


class TestTrackChain:
    """Tests for ``TrackChain`` and the helper methods it uses to build ``TrackObject`` chains from tracked calls."""

    @pytest.fixture
    def track_list(self):
        # One entry per stage, mapping each variable name to its ordered list of calls (access method and scope).
        stage1 = {'test': [{'method': 'set', 'scope': 'general.daytime'},
                           {'method': 'set', 'scope': 'general'},
                           {'method': 'get', 'scope': 'general'},
                           {'method': 'get', 'scope': 'general'}],
                  'another': [{'method': 'set', 'scope': 'general'}]}
        stage2 = {'sunlight': [{'method': 'set', 'scope': 'general'}],
                  'another': [{'method': 'get', 'scope': 'general.daytime'},
                              {'method': 'set', 'scope': 'general'},
                              {'method': 'set', 'scope': 'general.daytime'},
                              {'method': 'get', 'scope': 'general.daytime.noon'},
                              {'method': 'get', 'scope': 'general.nighttime'},
                              {'method': 'get', 'scope': 'general.daytime.noon'}]}
        stage3 = {'another': [{'method': 'get', 'scope': 'general.daytime'}],
                  'test': [{'method': 'get', 'scope': 'general'}],
                  'moonlight': [{'method': 'set', 'scope': 'general.daytime'}]}
        return [{'Stage1': stage1}, {'Stage2': stage2}, {'Stage3': stage3}]

    @pytest.fixture
    def all_scopes(self):
        # Scopes per variable of ``track_list``; "general" is listed even where only a sub-scope is accessed.
        return {"another": ["general", "general.daytime", "general.daytime.noon", "general.nighttime"],
                "moonlight": ["general", "general.daytime"],
                "sunlight": ["general"],
                "test": ["general", "general.daytime"]}

    @pytest.fixture
    def track_chain(self, track_list):
        return TrackChain(track_list)

    @pytest.fixture
    def track_chain_object(self):
        # bypass __init__ so that helper methods can be tested in isolation
        return object.__new__(TrackChain)

    def test_init(self, track_list):
        assert TrackChain(track_list).track_list == track_list

    def test_get_all_scopes(self, track_chain, track_list, all_scopes):
        assert track_chain.get_all_scopes(track_list) == all_scopes

    def test_get_unique_scopes(self, track_chain_object):
        calls = [{'method': 'get', 'scope': 'general.daytime'},
                 {'method': 'set', 'scope': 'general'},
                 {'method': 'set', 'scope': 'general.daytime'},
                 {'method': 'get', 'scope': 'general.daytime.noon'},
                 {'method': 'get', 'scope': 'general.nighttime'}]
        expected = ["general", "general.daytime", "general.daytime.noon", "general.nighttime"]
        assert sorted(track_chain_object.get_unique_scopes(calls)) == sorted(expected)

    def test_get_unique_scopes_no_general(self, track_chain_object):
        calls = [{'method': 'get', 'scope': 'general.daytime'},
                 {'method': 'get', 'scope': 'general.nighttime'}]
        # "general" is expected although no call uses it directly
        expected = ["general", "general.daytime", "general.nighttime"]
        assert sorted(track_chain_object.get_unique_scopes(calls)) == sorted(expected)

    def test_get_all_dims(self, track_chain_object, all_scopes):
        # one dimension per listed scope of a variable
        expected_dims = {"another": 4, "moonlight": 2, "sunlight": 1, "test": 2}
        assert track_chain_object.get_all_dims(all_scopes) == expected_dims

    def test_create_track_chain(self, track_chain):
        chain = track_chain.create_track_chain()
        stages = ["Stage1", "Stage2", "Stage3"]
        assert list(chain.keys()) == stages
        for stage in stages:
            assert len(chain[stage]) == 3

    def test_control_dict(self, track_chain_object, all_scopes):
        control = track_chain_object.control_dict(all_scopes)
        # every (variable, scope) slot starts out empty
        expected = {"another": {"general": None, "general.daytime": None, "general.daytime.noon": None,
                                "general.nighttime": None},
                    "moonlight": {"general": None, "general.daytime": None},
                    "sunlight": {"general": None},
                    "test": {"general": None, "general.daytime": None}}
        assert control == expected

    def test__create_track_chain(self, track_chain_object):
        control = {variable: {'general': None, 'general.sub': None} for variable in ('another', 'first', 'skip')}
        sorted_track_dict = OrderedDict([
            ("another", [{"method": "set", "scope": "general"},
                         {"method": "get", "scope": "general"},
                         {"method": "get", "scope": "general.sub"}]),
            ("first", [{"method": "set", "scope": "general.sub"},
                       {"method": "get", "scope": "general.sub"}]),
            ("skip", [{"method": "get", "scope": "general.sub"}]),
        ])
        track_objects, control = track_chain_object._create_track_chain(control, sorted_track_dict, "Stage1")
        assert len(track_objects) == 2
        assert control["another"]["general"] is not None
        assert control["first"]["general"] is None
        assert control["skip"]["general.sub"] is None

    def test_add_precursor(self, track_chain_object):
        current = TrackObject(["first", "get", "general"], "Stage1")
        previous = TrackObject(["first", "set", "general"], "Stage1")
        assert len(track_chain_object._add_precursor([], current, previous)) == 0
        assert current.precursor[0] == previous

    def test_add_track_object_same_stage(self, track_chain_object):
        current = TrackObject(["first", "get", "general"], "Stage1")
        previous = TrackObject(["first", "set", "general"], "Stage1")
        assert len(track_chain_object._add_track_object([], current, previous)) == 0

    def test_add_track_object_different_stage(self, track_chain_object):
        collected = []  # shared across both calls, so the result grows
        first_get = TrackObject(["first", "get", "general"], "Stage2")
        previous = TrackObject(["first", "set", "general"], "Stage1")
        assert len(track_chain_object._add_track_object(collected, first_get, previous)) == 1
        second_get = TrackObject(["first", "get", "general.sub"], "Stage2")
        assert len(track_chain_object._add_track_object(collected, second_get, previous)) == 2

    def test_update_control(self, track_chain_object):
        control = {'another': {'general': None, 'general.sub': None},
                   'first': {'general': None, 'general.sub': None}}
        track_chain_object._update_control(control, "first", "general", 23)
        assert control["first"]["general"] == 23

    def test_add_set_object(self, track_chain_object):
        collected = []
        current = TrackObject(["first", "set", "general"], "Stage1")
        same_stage = TrackObject(["first", "set", "general"], "Stage1")
        assert len(track_chain_object._add_set_object(collected, current, same_stage)) == 0
        assert len(current.precursor) == 1
        earlier_stage = TrackObject(["first", "set", "general"], "Stage0")
        assert len(track_chain_object._add_set_object(collected, current, earlier_stage)) == 1
        assert len(current.precursor) == 2

    def test_add_set_object_no_control_obj(self, track_chain_object):
        current = TrackObject(["first", "set", "general"], "Stage1")
        assert len(track_chain_object._add_set_object([], current, None)) == 1
        assert current.precursor is None

    def test_add_get_object_no_new_track_obj(self, track_chain_object):
        current = TrackObject(["first", "get", "general"], "Stage1")
        setter = TrackObject(["first", "set", "general"], "Stage1")
        control = {"testVar": {"general": setter, "general.sub": None}}
        result = track_chain_object._add_get_object([], current, setter, control, "general", "testVar")
        assert result == ([], False)
        assert setter.successor[0] == current

    def test_add_get_object_no_control_obj(self, track_chain_object):
        current = TrackObject(["first", "get", "general"], "Stage1")
        setter = TrackObject(["first", "set", "general"], "Stage1")
        control = {"testVar": {"general": setter, "general.sub": None}}
        result = track_chain_object._add_get_object([], current, None, control, "general.sub", "testVar")
        assert result == ([], False)
        assert setter.successor[0] == current

    def test_add_get_object_skip_update(self, track_chain_object):
        current = TrackObject(["first", "get", "general"], "Stage1")
        control = {"testVar": {"general": None, "general.sub": None}}
        result = track_chain_object._add_get_object([], current, None, control, "general.sub", "testVar")
        assert result == ([], True)

    def test_recursive_decent_avail_in_1_up(self, track_chain_object):
        expected_pre = TrackObject(["first", "set", "general"], "Stage1")
        found = track_chain_object._recursive_decent("general.sub", {"general": expected_pre})
        assert found == expected_pre

    def test_recursive_decent_avail_in_2_up(self, track_chain_object):
        expected_pre = TrackObject(["first", "set", "general"], "Stage1")
        control_obj_var = {"general": expected_pre, "general.sub": None}
        found = track_chain_object._recursive_decent("general.sub.sub", control_obj_var)
        assert found == expected_pre

    def test_recursive_decent_avail_from_chain(self, track_chain_object):
        expected_pre = TrackObject(["first", "set", "general"], "Stage1")
        expected_pre.add_successor(TrackObject(["first", "get", "general.sub"], "Stage1"))
        control_obj_var = {"general": expected_pre, "general.sub": expected_pre.successor[0]}
        found = track_chain_object._recursive_decent("general.sub.sub", control_obj_var)
        assert found == expected_pre

    def test_recursive_decent_avail_from_chain_get(self, track_chain_object):
        expected_pre = TrackObject(["first", "get", "general"], "Stage1")
        expected_pre.add_precursor(TrackObject(["first", "set", "general"], "Stage1"))
        control_obj_var = {"general": expected_pre, "general.sub": None}
        found = track_chain_object._recursive_decent("general.sub.sub", control_obj_var)
        assert found == expected_pre

    def test_recursive_decent_avail_from_chain_multiple_get(self, track_chain_object):
        expected_pre = TrackObject(["first", "get", "general"], "Stage1")
        origin = TrackObject(["first", "set", "general"], "Stage1")
        intermediate = TrackObject(["first", "get", "general"], "Stage1")
        # chain: origin (set) -> intermediate (get) -> expected_pre (get)
        expected_pre.add_precursor(intermediate)
        intermediate.add_precursor(origin)
        control_obj_var = {"general": expected_pre, "general.sub": None}
        found = track_chain_object._recursive_decent("general.sub.sub", control_obj_var)
        assert found == expected_pre

    def test_clean_control(self, track_chain_object):
        getter = TrackObject(["first", "get", "general"], "Stage1")
        setter = TrackObject(["first", "set", "general"], "Stage1")
        setter.add_precursor(getter)
        sub_getter = TrackObject(["first", "get", "general/sub"], "Stage1")
        sub_getter.add_precursor(getter)
        control = {'another': {'general': None, 'general.sub': None},
                   'first': {'general': setter, 'general.sub': sub_getter}}
        cleaned = track_chain_object.clean_control(control)
        expected = {'another': {'general': None, 'general.sub': None},
                    'first': {'general': setter, 'general.sub': None}}
        assert cleaned == expected


class TestTrackPlot:
    """Tests for ``TrackPlot``.

    Most tests bypass ``__init__`` via ``object.__new__`` and set only the attributes the method under test reads.
    """

    @pytest.fixture(autouse=True)
    def close_figures(self):
        # Tests open figures via plt.subplots() (and possibly via TrackPlot itself); close them afterwards so
        # figures do not accumulate over the session (matplotlib warns once too many are open).
        yield
        plt.close("all")

    @pytest.fixture
    def track_plot_obj(self):
        return object.__new__(TrackPlot)

    @pytest.fixture
    def track_list(self):
        return [{'Stage1': {'test': [{'method': 'set', 'scope': 'general.daytime'},
                                     {'method': 'set', 'scope': 'general'},
                                     {'method': 'get', 'scope': 'general'},
                                     {'method': 'get', 'scope': 'general'},],
                            'another': [{'method': 'set', 'scope': 'general'}]}},
                {'Stage2': {'sunlight': [{'method': 'set', 'scope': 'general'}],
                            'another': [{'method': 'get', 'scope': 'general.daytime'},
                                        {'method': 'set', 'scope': 'general'},
                                        {'method': 'set', 'scope': 'general.daytime'},
                                        {'method': 'get', 'scope': 'general.daytime.noon'},
                                        {'method': 'get', 'scope': 'general.nighttime'},
                                        {'method': 'get', 'scope': 'general.daytime.noon'}]}},
                {'RunEnvironment': {'another': [{'method': 'get', 'scope': 'general.daytime'}],
                                    'test': [{'method': 'get', 'scope': 'general'}],
                                    'moonlight': [{'method': 'set', 'scope': 'general.daytime'}]}}]

    @pytest.fixture
    def scopes(self):
        return {"another": ["general", "general.daytime", "general.daytime.noon", "general.nighttime"],
                "moonlight": ["general", "general.daytime"],
                "sunlight": ["general"],
                "test": ["general", "general.daytime"]}

    @pytest.fixture
    def dims(self):
        return {"another": 4, "moonlight": 2, "sunlight": 1, "test": 2}

    @pytest.fixture
    def track_chain_dict(self, track_list):
        return TrackChain(track_list).create_track_chain()

    @pytest.fixture
    def path(self, tmp_path):
        """Return an empty, existing plot folder inside pytest's per-test temporary directory.

        Using ``tmp_path`` keeps the source tree clean, isolates parallel runs and guarantees that teardown never
        removes a folder that existed before the test; pytest handles the cleanup.
        """
        folder = tmp_path / "TestExperiment"
        folder.mkdir()
        return str(folder)

    def test_init(self, path, track_list):
        assert "tracking.pdf" not in os.listdir(path)
        TrackPlot(track_list, plot_folder=path)
        assert "tracking.pdf" in os.listdir(path)

    @pytest.mark.skip(reason="not implemented yet")
    def test_plot(self):
        pass

    def test_line(self, track_plot_obj):
        h, w = 0.6, 0.65
        track_plot_obj.height = h
        track_plot_obj.width = w
        track_plot_obj.fig, track_plot_obj.ax = plt.subplots()
        assert len(track_plot_obj.ax.lines) == 0
        track_plot_obj.line(start_x=5, end_x=6, y=2)
        lines = track_plot_obj.ax.lines
        assert len(lines) == 2
        pos_x, pos_y = np.array([5 + w, 6]), np.ones((2, )) * (2 + h / 2)
        # a wider white line (2.5) followed by a thinner darkgrey line (1.4) along the same coordinates
        assert lines[0].get_color() == "white"
        assert lines[0].get_linewidth() == 2.5
        assert lines[1].get_color() == "darkgrey"
        assert lines[1].get_linewidth() == 1.4
        assert PyTestAllEqual([lines[0].get_xdata(orig=False), lines[1].get_xdata(orig=False), pos_x]).is_true()
        assert PyTestAllEqual([lines[0].get_ydata(orig=False), lines[1].get_ydata(orig=False), pos_y]).is_true()

    def test_step(self, track_plot_obj):
        x_int, h, w = 0.5, 0.6, 0.65
        track_plot_obj.space_intern_x = x_int
        track_plot_obj.height = h
        track_plot_obj.width = w
        track_plot_obj.fig, track_plot_obj.ax = plt.subplots()
        assert len(track_plot_obj.ax.lines) == 0
        track_plot_obj.step(start_x=5, end_x=6, start_y=2, end_y=3)
        lines = track_plot_obj.ax.lines
        assert len(lines) == 2
        pos_x = np.array([5 + w, 6 - x_int / 2, 6 - x_int / 2, 6])
        pos_y = np.array([2 + h / 2, 2 + h / 2, 3 + h / 2, 3 + h / 2])
        assert lines[0].get_color() == "white"
        assert lines[0].get_linewidth() == 2.5
        assert lines[1].get_color() == "black"
        assert lines[1].get_linewidth() == 1.4
        assert PyTestAllEqual([lines[0].get_xdata(orig=False), lines[1].get_xdata(orig=False), pos_x]).is_true()
        assert PyTestAllEqual([lines[0].get_ydata(orig=False), lines[1].get_ydata(orig=False), pos_y]).is_true()

    def test_rect(self, track_plot_obj):
        h, w = 0.5, 0.6
        track_plot_obj.height = h
        track_plot_obj.width = w
        track_plot_obj.fig, track_plot_obj.ax = plt.subplots()
        ax = track_plot_obj.ax
        assert len(ax.artists) == 0
        assert len(ax.texts) == 0

        # default method: "get"
        track_plot_obj.rect(x=4, y=2)
        assert len(ax.artists) == 1
        assert len(ax.texts) == 1
        get_rect, get_text = ax.artists[0], ax.texts[0]
        assert get_rect.xy == pytest.approx((4, 2))
        assert get_rect._height == h
        assert get_rect._width == w
        assert get_rect._original_facecolor == "orange"
        assert get_text.xy == pytest.approx((4 + w / 2, 2 + h / 2))
        assert get_text._color == "w"
        assert get_text._text == "get"

        # the second call adds a new rectangle/label, so inspect index 1 (index 0 is still the "get" one)
        track_plot_obj.rect(x=4, y=2, method="set")
        assert len(ax.artists) == 2
        assert len(ax.texts) == 2
        assert ax.artists[1]._original_facecolor == "lightblue"
        assert ax.texts[1]._text == "set"

    def test_set_ypos_anchor(self, track_plot_obj, scopes, dims):
        assert not hasattr(track_plot_obj, "y_pos")
        assert not hasattr(track_plot_obj, "anchor")
        y_int, y_ext, h = 0.5, 0.7, 0.6
        track_plot_obj.space_intern_y = y_int
        track_plot_obj.height = h
        track_plot_obj.space_extern_y = y_ext
        track_plot_obj.set_ypos_anchor(scopes, dims)
        total_dims = sum(dims.values())
        d_y = 0 - sum(factor * (y_int + h) + y_ext - y_int for factor in dims.values())
        expected_anchor = (d_y + total_dims, h + y_ext + total_dims)
        # raises AssertionError itself on mismatch; no need to wrap it in an assert
        np.testing.assert_array_almost_equal(track_plot_obj.anchor, expected_anchor)
        y_pos_another = track_plot_obj.y_pos["another"]
        assert y_pos_another["general"] == pytest.approx(total_dims)
        assert y_pos_another["general.daytime"] == pytest.approx(total_dims - (h + y_int))
        assert y_pos_another["general.daytime.noon"] == pytest.approx(total_dims - 2 * (h + y_int))

    @pytest.mark.skip(reason="not implemented yet")
    def test_plot_track_chain(self):
        pass

    @pytest.mark.skip(reason="not implemented yet")
    def test_add_variable_names(self):
        pass

    @pytest.mark.skip(reason="not implemented yet")
    def test_add_stages(self):
        pass

    @pytest.mark.skip(reason="not implemented yet")
    def test_create_track_chain_plot_run_env(self):
        pass

    def test_set_lims(self, track_plot_obj):
        track_plot_obj.x_max = 10
        track_plot_obj.space_intern_x = 0.5
        track_plot_obj.width = 0.4
        track_plot_obj.anchor = np.array((0.1, 12.5))
        track_plot_obj.fig, track_plot_obj.ax = plt.subplots()
        assert track_plot_obj.ax.get_ylim() == (0, 1)  # matplotlib default
        assert track_plot_obj.ax.get_xlim() == (0, 1)  # matplotlib default
        track_plot_obj.set_lims()
        assert track_plot_obj.ax.get_ylim() == pytest.approx((0.1, 12.5))
        # approx: the sum may be accumulated in a different order inside set_lims
        assert track_plot_obj.ax.get_xlim() == pytest.approx((0, 10 + 0.5 + 0.4))