import os
import shutil

import keras
import pytest

from mlair.model_modules.keras_extensions import LearningRateDecay
from mlair.plotting.training_monitoring import PlotModelLearningRate, PlotModelHistory


@pytest.fixture
def path():
    """Provide the ``TestExperiment`` output directory located next to this test file.

    The directory is created if it is missing and removed again once the test has
    finished. The plot tests assert that their output file does not exist yet before
    plotting, so without this teardown they would only pass on the very first run.

    Yields:
        str: absolute-or-relative path (as derived from ``__file__``) of the directory.
    """
    p = os.path.join(os.path.dirname(__file__), "TestExperiment")
    # exist_ok replaces the racy "check with os.path.exists, then create" sequence.
    os.makedirs(p, exist_ok=True)
    yield p
    # Teardown: drop all generated plots so the next test / next run starts clean.
    shutil.rmtree(p, ignore_errors=True)


class TestPlotModelHistory:
    """Tests for ``PlotModelHistory``: metric-name resolution, column filtering and file output."""

    # Per-epoch values (two epochs) reused by the history fixtures below.
    _VAL_LOSS = (0.5586272982587484, 0.45712877659670287)
    _VAL_MAE = (0.595368885413389, 0.530547587585537)
    _TRAIN_LOSS = (0.6795708956961347, 0.45963566494176616)
    _TRAIN_MAE = (0.6523177288928538, 0.5363963260296364)

    @staticmethod
    def _make_history(metrics):
        """Return a two-epoch keras ``History`` whose ``history`` holds fresh list copies of *metrics*."""
        hist = keras.callbacks.History()
        hist.epoch = [0, 1]
        hist.history = {name: list(values) for name, values in metrics.items()}
        return hist

    @pytest.fixture
    def default_history(self):
        return self._make_history({
            'val_loss': self._VAL_LOSS,
            'val_mean_squared_error': self._VAL_LOSS,
            'val_mean_absolute_error': self._VAL_MAE,
            'loss': self._TRAIN_LOSS,
            'mean_squared_error': self._TRAIN_LOSS,
            'mean_absolute_error': self._TRAIN_MAE,
        })

    @pytest.fixture
    def history_additional_loss(self):
        # Contains an extra 'test_loss' series but no validation MSE/MAE series.
        return self._make_history({
            'val_loss': self._VAL_LOSS,
            'test_loss': self._VAL_MAE,
            'loss': self._TRAIN_LOSS,
            'mean_squared_error': self._TRAIN_LOSS,
            'mean_absolute_error': self._TRAIN_MAE,
        })

    @pytest.fixture
    def history_with_main(self, default_history):
        # Adds "main_"-prefixed copies of the loss series on top of the default history.
        default_history.history.update(main_val_loss=list(self._VAL_LOSS),
                                       main_loss=list(self._TRAIN_LOSS))
        return default_history

    @pytest.fixture
    def no_init(self):
        # Bare instance: skips __init__ so no plotting or file output happens.
        return object.__new__(PlotModelHistory)

    def test_get_plot_metric(self, no_init, default_history):
        for requested in ("loss", "mean_squared_error"):
            resolved = no_init._get_plot_metric(default_history.history, plot_metric=requested, main_branch=False)
            assert resolved == requested

    def test_get_plot_metric_short_metric(self, no_init, default_history):
        abbreviations = {"mse": "mean_squared_error", "mae": "mean_absolute_error"}
        for short_name, full_name in abbreviations.items():
            resolved = no_init._get_plot_metric(default_history.history, plot_metric=short_name, main_branch=False)
            assert resolved == full_name

    def test_get_plot_metric_main_branch(self, no_init, history_with_main):
        resolved = no_init._get_plot_metric(history_with_main.history, plot_metric="loss", main_branch=True)
        assert resolved == "main_loss"

    def test_filter_columns(self, no_init):
        cases = [
            ("loss", ['loss', 'another_loss', 'val_loss', 'wrong'], ['another_loss']),
            ("mean_squared_error", ['mean_squared_error', 'another_loss', 'val_mean_squared_error', 'wrong'], []),
        ]
        for metric, columns, expected in cases:
            no_init._plot_metric = metric
            assert no_init._filter_columns(dict.fromkeys(columns)) == expected

    @staticmethod
    def _assert_plot_written(path, file_name, history):
        """Check *file_name* is absent in *path*, plot *history* into it, then check it was written."""
        assert file_name not in os.listdir(path)
        PlotModelHistory(os.path.join(path, file_name), history)
        assert file_name in os.listdir(path)

    def test_plot_from_hist_obj(self, default_history, path):
        self._assert_plot_written(path, "hist_obj.pdf", default_history)

    def test_plot_from_hist_dict(self, default_history, path):
        self._assert_plot_written(path, "hist_dict.pdf", default_history.history)

    def test_plot_additional_loss(self, history_additional_loss, path):
        self._assert_plot_written(path, "hist_additional.pdf", history_additional_loss)


class TestPlotModelLearningRate:
    """Tests that ``PlotModelLearningRate`` writes a plot file from either an object or a plain dict."""

    @pytest.fixture
    def learning_rate(self):
        decay = LearningRateDecay()
        # Twenty recorded rates: ten at 0.01 followed by ten at 0.0094.
        decay.lr = {"lr": [0.01] * 10 + [0.0094] * 10}
        return decay

    @staticmethod
    def _assert_plot_written(path, file_name, data):
        """Check *file_name* is absent in *path*, plot *data* into it, then check it was written."""
        assert file_name not in os.listdir(path)
        PlotModelLearningRate(os.path.join(path, file_name), data)
        assert file_name in os.listdir(path)

    def test_plot_from_lr_obj(self, learning_rate, path):
        self._assert_plot_written(path, "lr_obj.pdf", learning_rate)

    def test_plot_from_lr_dict(self, learning_rate, path):
        self._assert_plot_written(path, "lr_dict.pdf", learning_rate.lr)