from scipy import stats

from src.run_modules.run_environment import RunEnvironment

__author__ = 'Lukas Leufen, Felix Kleinert'
__date__ = '2019-10-23'

import numpy as np
import xarray as xr
import pandas as pd
from typing import Union, Tuple


# Type alias for the data containers accepted and returned by the transformation helpers in this module.
Data = Union[xr.DataArray, pd.DataFrame]


def apply_inverse_transformation(data, mean, std=None, method="standardise"):
    """
    Revert a transformation that was applied to data by dispatching to the matching inverse function.

    :param data: transformed data
    :param mean: mean that is handed to the inverse function
    :param std: standard deviation, only handed over for method 'standardise'
    :param method: name of the transformation to revert, either 'standardise' or 'centre'
    :return: result of `standardise_inverse` or `centre_inverse`
    :raises NotImplementedError: for method 'normalise' and for every unknown method name
    """
    if method == 'standardise':  # pragma: no branch
        return standardise_inverse(data, mean, std)
    elif method == 'centre':  # pragma: no branch
        return centre_inverse(data, mean)
    elif method == 'normalise':  # pragma: no cover
        # use min/max of data or given min/max
        raise NotImplementedError("Inverse transformation for method 'normalise' is not implemented yet.")
    else:
        raise NotImplementedError(f"Unknown transformation method '{method}', use 'standardise' or 'centre'.")


def standardise(data: Data, dim: Union[str, int]) -> Tuple[Data, Data, Data]:
    """
    This function standardises a xarray.dataarray (along dim) or pandas.DataFrame (along axis) with mean=0 and std=1

    NOTE(review): for a pandas.DataFrame the subtraction / division aligns the statistics on the column labels, so
    axis=0 looks like the only axis that broadcasts as intended - verify before using axis=1.

    :param data: data to standardise
    :param string/int dim:
            | for xarray.DataArray as string: name of dimension which should be standardised
            | for pandas.DataFrame as int: axis of dimension which should be standardised
    :return: xarray.DataArrays or pandas.DataFrames:
            #. mean: Mean of data
            #. std: Standard deviation of data
            #. data: Standardised data
    """
    # compute each statistic only once and reuse it for the returned values and the transformation itself
    mean = data.mean(dim)
    std = data.std(dim)
    return mean, std, (data - mean) / std


def standardise_inverse(data: Data, mean: Data, std: Data) -> Data:
    """
    Undo `standardise`: scale the data with std first and shift it by mean afterwards.

    :param data: standardised data
    :param mean: mean to add
    :param std: standard deviation to multiply with
    :return: data * std + mean
    """
    rescaled = data * std
    return rescaled + mean


def standardise_apply(data: Data, mean: Data, std: Data) -> Data:
    """
    Standardise data with an already known mean and std instead of calculating them from data.

    :param data: data to transform
    :param mean: mean to subtract
    :param std: standard deviation to divide by
    :return: (data - mean) / std
    """
    shifted = data - mean
    return shifted / std


def centre(data: Data, dim: Union[str, int]) -> Tuple[Data, None, Data]:
    """
    This function centres a xarray.dataarray (along dim) or pandas.DataFrame (along axis) to mean=0

    :param data: data to centre
    :param string/int dim:
            | for xarray.DataArray as string: name of dimension which should be centred
            | for pandas.DataFrame as int: axis of dimension which should be centred
    :return: xarray.DataArrays or pandas.DataFrames:
            #. mean: Mean of data
            #. std: always None, placeholder to keep the same return layout as `standardise`
            #. data: Centred data
    """
    # compute the mean only once and reuse it for the returned value and the transformation itself
    mean = data.mean(dim)
    return mean, None, data - mean


def centre_inverse(data: Data, mean: Data) -> Data:
    """
    Undo `centre` by shifting the data back by the given mean.

    :param data: centred data
    :param mean: mean to add
    :return: data + mean
    """
    restored = data + mean
    return restored


def centre_apply(data: Data, mean: Data) -> Data:
    """
    Centre data with an already known mean instead of calculating it from data.

    :param data: data to transform
    :param mean: mean to subtract
    :return: data - mean
    """
    centred = data - mean
    return centred


def mean_squared_error(a, b):
    """
    Calculate the mean of the squared element-wise differences between a and b.

    :param a: first operand
    :param b: second operand
    :return: result of calling `.mean()` on the squared differences
    """
    residuals = a - b
    return np.square(residuals).mean()


class SkillScores:
    """
    Calculate skill scores for the forecasts stored in `internal_data`.

    `internal_data` is accessed with `.sel(ahead=...)`, `.sel(type=...)` and `.dropna("index")`, so it has to provide
    the dimensions `index`, `ahead` and `type` (presumably an xarray.DataArray - verify against caller).
    """

    def __init__(self, internal_data):
        """
        :param internal_data: data holding observations and all forecasts
        """
        self.internal_data = internal_data

    def skill_scores(self, window_lead_time):
        """
        Calculate the general skill score for the combinations CNN vs. persi, OLS vs. persi and CNN vs. OLS.

        :param window_lead_time: number of lead times, scores are calculated for ahead = 1, ..., window_lead_time
        :return: pandas.DataFrame with one row per combination and one column per lead time
        """
        ahead_names = list(range(1, window_lead_time + 1))
        skill_score = pd.DataFrame(index=['cnn-persi', 'ols-persi', 'cnn-ols'])
        for iahead in ahead_names:
            data = self.internal_data.sel(ahead=iahead)
            skill_score[iahead] = [self.general_skill_score(data, forecast_name="CNN", reference_name="persi"),
                                   self.general_skill_score(data, forecast_name="OLS", reference_name="persi"),
                                   self.general_skill_score(data, forecast_name="CNN", reference_name="OLS")]
        return skill_score

    def climatological_skill_scores(self, external_data, window_lead_time):
        """
        Calculate the skill scores of the CNN forecast and their individual terms for all four mu cases.

        Case I and II only use the internal data. Case III and IV are only calculated if external data is given,
        otherwise their entries stay NaN.

        :param external_data: data handed to case III and IV, or None to skip these cases
        :param window_lead_time: number of lead times, scores are calculated for ahead = 1, ..., window_lead_time
        :return: xarray.DataArray with dimensions `terms` and `ahead`
        """
        ahead_names = list(range(1, window_lead_time + 1))

        all_terms = ['AI', 'AII', 'AIII', 'AIV', 'BI', 'BII', 'BIV', 'CI', 'CIV', 'CASE I', 'CASE II', 'CASE III',
                     'CASE IV']
        skill_score = xr.DataArray(np.full((len(all_terms), len(ahead_names)), np.nan), coords=[all_terms, ahead_names],
                                   dims=['terms', 'ahead'])

        for iahead in ahead_names:

            data = self.internal_data.sel(ahead=iahead)

            # the case methods build their result with the skill score first and the single terms afterwards, the
            # list of terms on the left hand side lists the target entries in the same order
            skill_score.loc[["CASE I", "AI", "BI", "CI"], iahead] = np.stack(self._climatological_skill_score(
                data, mu_type=1, forecast_name="CNN").values.flatten())

            skill_score.loc[["CASE II", "AII", "BII"], iahead] = np.stack(self._climatological_skill_score(
                data, mu_type=2, forecast_name="CNN").values.flatten())

            if external_data is not None:
                skill_score.loc[["CASE III", "AIII"], iahead] = np.stack(self._climatological_skill_score(
                    data, mu_type=3, forecast_name="CNN",
                    external_data=external_data).values.flatten())

                skill_score.loc[["CASE IV", "AIV", "BIV", "CIV"], iahead] = np.stack(self._climatological_skill_score(
                    data, mu_type=4, forecast_name="CNN",
                    external_data=external_data).values.flatten())

        return skill_score

    def _climatological_skill_score(self, data, mu_type=1, observation_name="obs", forecast_name="CNN", external_data=None):
        """
        Dispatch to the method `skill_score_mu_case_<mu_type>`.

        :param data: data handed to the case method
        :param mu_type: number of the case (1, 2, 3 or 4)
        :param observation_name: label of the observations along `type`
        :param forecast_name: label of the forecast along `type`
        :param external_data: only passed on if not None (case 1 and 2 do not accept this argument)
        :return: return value of the selected case method
        """
        kwargs = {"external_data": external_data} if external_data is not None else {}
        return getattr(self, f"skill_score_mu_case_{mu_type}")(data, observation_name, forecast_name, **kwargs)

    @staticmethod
    def general_skill_score(data, observation_name="obs", forecast_name="CNN", reference_name="persi"):
        """
        Calculate the skill score 1 - mse(observation, forecast) / mse(observation, reference).

        Entries containing NaN are dropped along `index` before the calculation.

        :param data: data to select observation, forecast and reference from
        :param observation_name: label of the observations along `type`
        :param forecast_name: label of the forecast along `type`
        :param reference_name: label of the reference forecast along `type`
        :return: `.values` of the calculated skill score
        """
        data = data.dropna("index")
        observation = data.sel(type=observation_name)
        forecast = data.sel(type=forecast_name)
        reference = data.sel(type=reference_name)
        mse = mean_squared_error
        skill_score = 1 - mse(observation, forecast) / mse(observation, reference)
        return skill_score.values

    @staticmethod
    def skill_score_pre_calculations(data, observation_name, forecast_name):
        """
        Calculate the terms that are shared by all mu cases.

        The data is reduced to observation and forecast, the `ahead` coordinate is dropped and entries containing
        NaN are removed along `index`. Afterwards the following terms are calculated, with r as Pearson correlation
        between forecast and observation:

            * AI = r ** 2
            * BI = (r - sigma_forecast / sigma_observation) ** 2
            * CI = ((mean_forecast - mean_observation) / sigma_observation) ** 2

        :param data: data to select observation and forecast from
        :param observation_name: label of the observations
        :param forecast_name: label of the forecast
        :return: AI, BI, CI, the reduced data and a dictionary with the keys `mean`, `sigma`, `r` and `p`
        """

        data = data.loc[..., [observation_name, forecast_name]].drop("ahead")
        data = data.dropna("index")

        mean = data.mean("index")
        sigma = np.sqrt(data.var("index"))
        # r, p = stats.spearmanr(data.loc[..., [forecast_name, observation_name]])
        r, p = stats.pearsonr(data.loc[..., forecast_name], data.loc[..., observation_name])

        AI = np.array(r ** 2)
        BI = ((r - (sigma.loc[..., forecast_name] / sigma.loc[..., observation_name])) ** 2).values
        CI = (((mean.loc[..., forecast_name] - mean.loc[..., observation_name]) / sigma.loc[
            ..., observation_name]) ** 2).values

        suffix = {"mean": mean, "sigma": sigma, "r": r, "p": p}
        return AI, BI, CI, data, suffix

    def skill_score_mu_case_1(self, data, observation_name="obs", forecast_name="CNN"):
        """
        Calculate the skill score of case I as AI - BI - CI.

        :param data: data handed to `skill_score_pre_calculations`
        :param observation_name: label of the observations
        :param forecast_name: label of the forecast
        :return: array with the entries skill_score, AI, BI and CI
        """
        AI, BI, CI, data, _ = self.skill_score_pre_calculations(data, observation_name, forecast_name)
        skill_score = np.array(AI - BI - CI)
        return pd.DataFrame({"skill_score": [skill_score], "AI": [AI], "BI": [BI], "CI": [CI]}).to_xarray().to_array()

    def skill_score_mu_case_2(self, data, observation_name="obs", forecast_name="CNN"):
        """
        Calculate the skill score of case II, which additionally uses the monthly means of the given data.

        The monthly means are appended to the data along `type` with the suffix "X". AII and BII are calculated
        from the correlation between the observations and their monthly means.

        :param data: data handed to `skill_score_pre_calculations`
        :param observation_name: label of the observations
        :param forecast_name: label of the forecast
        :return: array with the entries skill_score, AII and BII
        """
        AI, BI, CI, data, suffix = self.skill_score_pre_calculations(data, observation_name, forecast_name)
        monthly_mean = self.create_monthly_mean_from_daily_data(data)
        data = xr.concat([data, monthly_mean], dim="type")
        sigma = suffix["sigma"]
        sigma_monthly = np.sqrt(monthly_mean.var())
        # r, p = stats.spearmanr(data.loc[..., [observation_name, observation_name + "X"]])
        r, p = stats.pearsonr(data.loc[..., observation_name], data.loc[..., observation_name + "X"])
        AII = np.array(r ** 2)
        BII = ((r - sigma_monthly / sigma.loc[observation_name]) ** 2).values
        skill_score = np.array((AI - BI - CI - AII + BII) / (1 - AII + BII))
        return pd.DataFrame({"skill_score": [skill_score], "AII": [AII], "BII": [BII]}).to_xarray().to_array()

    def skill_score_mu_case_3(self, data, observation_name="obs", forecast_name="CNN", external_data=None):
        """
        Calculate the skill score of case III, which additionally uses the overall mean of the external data.

        :param data: data handed to `skill_score_pre_calculations`
        :param observation_name: label of the observations
        :param forecast_name: label of the forecast
        :param external_data: external data, its mean over all entries is used (must not be None)
        :return: array with the entries skill_score and AIII
        """
        AI, BI, CI, data, suffix = self.skill_score_pre_calculations(data, observation_name, forecast_name)
        mean, sigma = suffix["mean"], suffix["sigma"]
        AIII = (((external_data.mean().values - mean.loc[observation_name]) / sigma.loc[observation_name])**2).values
        # the complete sum (1 + AIII) is the denominator, same structure as in case II and IV; without the
        # parentheses the numerator would be divided by 1 and AIII would be added afterwards
        skill_score = np.array((AI - BI - CI + AIII) / (1 + AIII))
        return pd.DataFrame({"skill_score": [skill_score], "AIII": [AIII]}).to_xarray().to_array()

    def skill_score_mu_case_4(self, data, observation_name="obs", forecast_name="CNN", external_data=None):
        """
        Calculate the skill score of case IV, which additionally uses the monthly means of the external data.

        :param data: data handed to `skill_score_pre_calculations`
        :param observation_name: label of the observations
        :param forecast_name: label of the forecast
        :param external_data: external data to calculate monthly means from (must not be None)
        :return: array with the entries skill_score, AIV, BIV and CIV
        """
        AI, BI, CI, data, suffix = self.skill_score_pre_calculations(data, observation_name, forecast_name)
        # monthly means of the external data on the index of the internal data, appended with suffix "X"
        monthly_mean_external = self.create_monthly_mean_from_daily_data(external_data, columns=data.type.values, index=data.index)
        data = xr.concat([data, monthly_mean_external], dim="type")
        mean, sigma = suffix["mean"], suffix["sigma"]
        # NOTE(review): the monthly means are calculated a second time, now on the index of the external data and
        # with the column names of the concatenated data. Looks intended for mean / sigma below - TODO confirm.
        monthly_mean_external = self.create_monthly_mean_from_daily_data(external_data, columns=data.type.values)
        mean_external = monthly_mean_external.mean()
        sigma_external = np.sqrt(monthly_mean_external.var())

        # r_mu, p_mu = stats.spearmanr(data.loc[..., [observation_name, observation_name+'X']])
        r_mu, p_mu = stats.pearsonr(data.loc[..., observation_name], data.loc[..., observation_name + "X"])

        AIV = np.array(r_mu**2)
        BIV = ((r_mu - sigma_external / sigma.loc[observation_name])**2).values
        CIV = (((mean_external - mean.loc[observation_name]) / sigma.loc[observation_name])**2).values
        skill_score = np.array((AI - BI - CI - AIV + BIV + CIV) / (1 - AIV + BIV + CIV))
        # every column is given as list of length one, CIV included, to build a single-row frame consistently
        return pd.DataFrame({"skill_score": [skill_score], "AIV": [AIV], "BIV": [BIV], "CIV": [CIV]}).to_xarray().to_array()

    @staticmethod
    def create_monthly_mean_from_daily_data(data, columns=None, index=None):
        """
        Create an array in which each entry holds the mean of its calendar month.

        The means are calculated from data grouped by the month of `index`. The result has the dimensions `index`
        and `type`, where the `type` labels are the given columns with the suffix "X". Entries of months that do not
        occur in data stay NaN.

        :param data: data to calculate the monthly means from, `index` has to hold datetime values
        :param columns: names of the columns without suffix, default are the `type` values of data
        :param index: index of the returned array, default is the index of data
        :return: xarray.DataArray with the monthly mean for each entry of index
        """
        if columns is None:
            columns = data.type.values
        if index is None:
            index = data.index
        coordinates = [index, [v + "X" for v in list(columns)]]
        empty_data = np.full((len(index), len(columns)), np.nan)
        monthly_mean = xr.DataArray(empty_data, coords=coordinates, dims=["index", "type"])
        mu = data.groupby("index.month").mean()

        for month in mu.month:
            # write the mean of this month to all entries that belong to the same month
            monthly_mean[monthly_mean.index.dt.month == month, :] = mu[mu.month == month].values

        return monthly_mean