__author__ = "Lukas Leufen, Felix Kleinert"
__date__ = '2019-12-17'

import os
import logging
import math
import warnings
from src import helpers
from src.helpers import TimeTracking

import numpy as np
import xarray as xr
import pandas as pd

import matplotlib
import seaborn as sns
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from matplotlib.backends.backend_pdf import PdfPages

from typing import Dict, List

# Only let matplotlib's own logger pass warnings and errors, so its debug/info records do not flood the log when the
# application runs at DEBUG level (this module itself emits many logging.debug messages).
logging.getLogger('matplotlib').setLevel(logging.WARNING)


def plot_monthly_summary(stations: List, data_path: str, name: str, target_var: str, window_lead_time: int = None,
                         plot_folder: str = "."):
    """
    Show a monthly summary over all stations for each lead time ("ahead") as box and whiskers plot.

    The observations ("orig", taken from ahead=1) are drawn in green, the CNN forecasts of every plotted lead time in
    shades of blue. Negative values are clipped to 0 before plotting. The plot is saved in plot_folder with the name
    monthly_summary_box_plot.pdf and 500dpi resolution.

    :param stations: all stations to plot
    :param data_path: path, where the data is located
    :param name: full name of the local files with a % as placeholder for the station name
    :param target_var: display name of the target variable on plot's axis
    :param window_lead_time: number of lead times to plot. If window_lead_time is higher than the available lead time
        or not given, the maximum lead time from data is used. (default None -> use maximum lead time from data).
    :param plot_folder: path to save the plot (default: current directory)
    :raises ValueError: if no station is given
    """
    logging.debug("run plot_monthly_summary()")
    forecasts = None

    for station in stations:
        logging.debug(f"... preprocess station {station}")
        file_name = os.path.join(data_path, name % station)
        # read the data into memory and release the file handle right away
        with xr.open_dataarray(file_name) as raw_data:
            data = raw_data.load()

        data_cnn = data.sel(type="CNN").squeeze()
        # dimension coordinates must not be modified in place (newer xarray raises), so relabel via assign_coords
        data_cnn = data_cnn.assign_coords(ahead=[f"{days}d" for days in data_cnn.coords["ahead"].values])

        data_orig = data.sel(type="orig", ahead=1).squeeze()
        data_orig = data_orig.assign_coords(ahead="orig")

        data_concat = xr.concat([data_orig, data_cnn], dim="ahead")
        data_concat = data_concat.drop("type")

        # replace the time stamps by their month number (1 = January, ..., 12 = December)
        months = data_concat.index.values.astype("datetime64[M]").astype(int) % 12 + 1
        data_concat = data_concat.assign_coords(index=months)
        data_concat = data_concat.clip(min=0)

        forecasts = data_concat if forecasts is None else xr.concat([forecasts, data_concat], 'index')

    if forecasts is None:
        raise ValueError("No stations given, nothing to plot in plot_monthly_summary().")

    # "ahead" contains the observation label "orig" plus one label per CNN lead time; only the latter are lead times
    cnn_labels = [label for label in forecasts.coords["ahead"].values if label != "orig"]
    if window_lead_time is None:
        window_lead_time = len(cnn_labels)
    window_lead_time = min(len(cnn_labels), window_lead_time)
    # keep the observations (first hue level, drawn green) and the requested number of lead times
    forecasts = forecasts.sel(ahead=["orig"] + cnn_labels[:window_lead_time])

    forecasts = forecasts.to_dataset(name='values').to_dask_dataframe()
    logging.debug("... start plotting")
    palette = [matplotlib.colors.cnames["green"]] + sns.color_palette("Blues_d", window_lead_time).as_hex()
    # draw on a fresh figure so no previously open plot is reused
    fig, ax = plt.subplots()
    sns.boxplot(x='index', y='values', hue='ahead', data=forecasts.compute(), whis=1., palette=palette, ax=ax,
                flierprops={'marker': '.', 'markersize': 1}, showmeans=True,
                meanprops={'markersize': 1, 'markeredgecolor': 'k'})
    ax.set(xlabel='month', ylabel=f'{target_var}')
    plt.tight_layout()
    plot_name = os.path.join(os.path.abspath(plot_folder), 'monthly_summary_box_plot.pdf')
    logging.debug(f"... save plot to {plot_name}")
    plt.savefig(plot_name, dpi=500)
    plt.close('all')


def plot_station_map(generators: Dict, plot_folder: str = "."):
    """
    Draw a geographical overview of all stations used.

    Every entry of ``generators`` maps a matplotlib colour (key) to a data generator (value); each station provided by
    that generator is drawn as a black-edged square filled with this colour. The map covers 0-20 deg longitude and
    42-58 deg latitude. Coastlines, lakes, oceans, rivers and borders from cartopy are drawn as background. The figure
    is written to ``<plot_folder>/station_map.pdf``.

    :param generators: dictionary with the plot colour of each data set as key and the generator containing all
        stations as value; None draws the background map only.
    :param plot_folder: path to save the plot (default: current directory)
    """
    logging.debug("run station_map()")
    plate_carree = ccrs.PlateCarree()

    figure = plt.figure(figsize=(10, 5))
    map_axis = figure.add_subplot(1, 1, 1, projection=plate_carree)
    map_axis.set_extent([0, 20, 42, 58], crs=plate_carree)

    # background layers, added in this order
    background_layers = [
        (cfeature.COASTLINE.with_scale("10m"), {"edgecolor": 'black'}),
        (cfeature.LAKES.with_scale("50m"), {}),
        (cfeature.OCEAN.with_scale("50m"), {}),
        (cfeature.RIVERS.with_scale("10m"), {}),
        (cfeature.BORDERS.with_scale("10m"), {"facecolor": 'none', "edgecolor": 'black'}),
    ]
    for feature, feature_kwargs in background_layers:
        map_axis.add_feature(feature, **feature_kwargs)

    station_sets = generators.items() if generators is not None else ()
    for colour, generator in station_sets:
        for position, _ in enumerate(generator):
            coordinates = generator.get_data_generator(position).meta.loc[['station_lon', 'station_lat']]
            longitude = float(coordinates.loc['station_lon'].values)
            latitude = float(coordinates.loc['station_lat'].values)
            map_axis.plot(longitude, latitude, mfc=colour, mec='k', marker='s', markersize=6,
                          transform=plate_carree)

    target_file = os.path.join(os.path.abspath(plot_folder), 'station_map.pdf')
    logging.debug(f"... save plot to {target_file}")
    plt.savefig(target_file, dpi=500)
    plt.close('all')


def plot_conditional_quantiles(stations: list, plot_folder: str = ".", rolling_window: int = 3, ref_name: str = 'orig',
                               pred_name: str = 'CNN', season: str = "", forecast_path: str = None,
                               plot_name_affix: str = "", units: str = "ppb"):
    """
    Plot conditional quantiles of the ref_name data given binned pred_name data (one pdf page per lead time).

    This plot was originally taken from Murphy, Brown and Chen (1989):
    https://journals.ametsoc.org/doi/pdf/10.1175/1520-0434%281989%29004%3C0485%3ADVOTF%3E2.0.CO%3B2

    The pred_name values are segmented into bins of width 1 ranging from 0 to the ceiling of the overall data maximum.
    For each bin, the .10/.25/.50/.75/.90 quantiles of the ref_name values in that bin are drawn. A 1:1 reference line
    and a log-scaled histogram of the bin populations are added. The multi-page pdf is saved as
    conditional_quantiles[_{season}][_{plot_name_affix}]_plot.pdf in plot_folder.

    :param stations: stations to include in the plot (forecast data needs to be available already as
        forecasts_{station}_test.nc in forecast_path)
    :param plot_folder: path to save the plot (default: current directory)
    :param rolling_window: the rolling window mean will smooth the plot appearance (no smoothing in bin calculation,
        this is only a cosmetic step, default: 3)
    :param ref_name: name of the reference data series
    :param pred_name: name of the investigated data series
    :param season: season name to highlight if not empty
    :param forecast_path: path where the forecast files are read from (required)
    :param plot_name_affix: name to specify this plot (e.g. 'cali-ref', default: '')
    :param units: units of the forecasted values (default: ppb)
    :raises ValueError: if forecast_path is not given
    """
    time_tracker = TimeTracking()
    logging.debug("started plot_conditional_quantiles()")

    # check forecast path before doing any work
    if forecast_path is None:
        raise ValueError("Forecast path is not given but required.")

    def load_data():
        """Load CNN, orig and OLS data of all stations into one array with dims (index, type, ahead, station)."""
        logging.debug("... load data")
        data_collector = []
        for station in stations:
            file = os.path.join(forecast_path, f"forecasts_{station}_test.nc")
            # load the selection into memory so the file handle is closed right away
            with xr.open_dataarray(file) as data_tmp:
                selection = data_tmp.loc[:, :, ['CNN', 'orig', 'OLS']].assign_coords(station=station)
                data_collector.append(selection.load())
        return xr.concat(data_collector, dim='station').transpose('index', 'type', 'ahead', 'station')

    def segment_data(data, bins):
        """Stack index and station into a position dim "z" and replace pred_name values by their bin label."""
        logging.debug("... segment data")
        # combine index and station to multi index
        data = data.stack(z=['index', 'station'])
        # replace multi index by simple position index (order is not relevant anymore)
        # NOTE(review): newer xarray versions deprecate overwriting a MultiIndex coordinate in place - confirm
        # against the pinned xarray version.
        data.coords['z'] = range(len(data.coords['z']))
        # segment data of pred_name into bins, each value is labelled by the upper edge of its bin
        data.loc[pred_name, ...] = data.loc[pred_name, ...].to_pandas().T.apply(pd.cut, bins=bins,
                                                                                labels=bins[1:]).T.values
        return data

    def create_quantile_panel(data, q, bins):
        """Return quantiles q of ref_name per pred_name bin as array with dims (ahead, quantiles, categories)."""
        logging.debug("... create quantile panel")
        # create empty xarray with dims: time steps ahead, quantiles, bin index (numbers create in previous step)
        quantile_panel = xr.DataArray(np.full([data.ahead.shape[0], len(q), bins[1:].shape[0]], np.nan),
                                      coords=[data.ahead, q, bins[1:]], dims=['ahead', 'quantiles', 'categories'])
        # ensure that the coordinates are in the right order
        quantile_panel = quantile_panel.transpose('ahead', 'quantiles', 'categories')
        # calculate for each bin of the pred_name data the quantiles of the ref_name data
        for bin_label in bins[1:]:
            mask = (data.loc[pred_name, ...] == bin_label)
            quantile_panel.loc[..., bin_label] = data.loc[ref_name, ...].where(mask).quantile(q, dim=['z']).T
        return quantile_panel

    def axis_labels(plot_type, data_unit="ppb"):
        """Return (xlabel, ylabel); forecast on x for plot_type "orig", observation on x otherwise."""
        names = (f"forecast concentration (in {data_unit})", f"observed concentration (in {data_unit})")
        return names if plot_type == "orig" else names[::-1]

    def add_affix(x):
        return f"_{x}" if len(x) > 0 else ""

    xlabel, ylabel = axis_labels(ref_name, units)

    opts = {"q": [.1, .25, .5, .75, .9], "linetype": [':', '-.', '--', '-.', ':'],
            "legend": ['.10th and .90th quantile', '.25th and .75th quantile', '.50th quantile', 'reference 1:1'],
            "xlabel": xlabel, "ylabel": ylabel}

    # set name and path of the plot
    base_name = "conditional_quantiles"
    plot_name = f"{base_name}{add_affix(season)}{add_affix(plot_name_affix)}_plot.pdf"
    plot_path = os.path.join(os.path.abspath(plot_folder), plot_name)

    # scope the warning filters to this function: catch_warnings restores the previous global filters on exit
    with warnings.catch_warnings():
        # ignore warnings if nans appear in quantile grouping
        warnings.filterwarnings("ignore", message="All-NaN slice encountered")
        # ignore warnings if mean is calculated on nans
        warnings.filterwarnings("ignore", message="Mean of empty slice")
        # ignore warnings for y tick = 0 on log scale (instead of 0.00001 or similar)
        warnings.filterwarnings("ignore", message="Attempted to set non-positive bottom ylim on a log-scaled axis.")

        # load data and set data bins
        orig_data = load_data()
        bins = np.arange(0, math.ceil(orig_data.max().max()) + 1, 1).astype(int)
        segmented_data = segment_data(orig_data, bins)
        quantile_panel = create_quantile_panel(segmented_data, opts["q"], bins)

        # the context manager finalises the pdf even if plotting fails
        with PdfPages(plot_path) as pdf_pages:
            logging.debug(f"... plot path is {plot_path}")

            # create plot for each time step ahead
            y2_max = 0
            for iteration, d in enumerate(segmented_data.ahead):
                logging.debug(f"... plotting {d.values} time step(s) ahead")
                # plot smoothed lines with rolling mean
                smooth_data = quantile_panel.loc[d, ...].rolling(categories=rolling_window, center=True).mean()
                smooth_data = smooth_data.to_pandas().T
                ax = smooth_data.plot(style=opts["linetype"], color='black', legend=False)
                ax2 = ax.twinx()
                # add reference line
                ax.plot([0, bins.max()], [0, bins.max()], color='k', label='reference 1:1', linewidth=.8)
                # add histogram of the segmented data (pred_name)
                handles, _ = ax.get_legend_handles_labels()
                segmented_data.loc[pred_name, d, :].to_pandas().hist(bins=bins, ax=ax2, color='k', alpha=.3,
                                                                     grid=False, rwidth=1)
                # add legend (quantile lines and reference line)
                plt.legend(handles[:3] + [handles[-1]], opts["legend"], loc='upper left', fontsize='large')
                # adjust limits and set labels
                ax.set(xlim=(0, bins.max()), ylim=(0, bins.max()))
                ax.set_xlabel(opts["xlabel"], fontsize='x-large')
                ax.tick_params(axis='x', which='major', labelsize=15)
                ax.set_ylabel(opts["ylabel"], fontsize='x-large')
                ax.tick_params(axis='y', which='major', labelsize=15)
                ax2.yaxis.label.set_color('gray')
                ax2.tick_params(axis='y', colors='gray')
                ax2.yaxis.labelpad = -15
                ax2.set_yscale('log')
                # the histogram y-range of the first page is reused for all pages
                if iteration == 0:
                    y2_max = ax2.get_ylim()[1] + 100
                ax2.set(ylim=(0, y2_max * 10 ** 8), yticks=np.logspace(0, 4, 5))
                ax2.set_ylabel('              sample size', fontsize='x-large')
                ax2.tick_params(axis='y', which='major', labelsize=15)
                # set title and save current figure
                title = f"{d.values} time step(s) ahead{f' ({season})' if len(season) > 0 else ''}"
                plt.title(title)
                pdf_pages.savefig()
                # close this page's figure now instead of keeping one open figure per lead time
                plt.close(ax.figure)

    # close all remaining figures / plots
    plt.close('all')
    logging.info(f"plot_conditional_quantiles() finished after {time_tracker}")


def plot_climatological_skill_score(data: Dict, plot_folder: str = ".", score_only: bool = True,
                                    extra_name_tag: str = "", model_setup: str = ""):
    """
    Box plot of the climatological skill score after Murphy (1988) over all stations.

    Every forecast time step ("ahead") is drawn as its own hue to highlight the differences between prediction time
    steps. With score_only=True (default), only the resulting scores CASE I to IV are shown. Otherwise every single
    term is plotted on an enlarged figure. The y-axis follows the data. The figure is stored as
    skill_score_clim_{extra_name_tag}{model_setup}.pdf in plot_folder with 500dpi resolution.

    :param data: dictionary with station names as keys and 2D xarrays as values, consisting of axis ahead and terms.
    :param plot_folder: path to save the plot (default: current directory)
    :param score_only: if true plot only scores of CASE I to IV, otherwise plot all single terms (default True)
    :param extra_name_tag: additional tag that can be included in the plot name (default "")
    :param model_setup: architecture type appended to the plot name (default "")
    """
    logging.debug("run plot_climatological_skill_score()")
    skill_scores = helpers.dict_to_xarray(data, "station")
    legend_labels = [str(step) + "d" for step in skill_scores.coords["ahead"].values]

    fig, ax = plt.subplots()
    if not score_only:
        # all single terms need more horizontal space
        fig.set_size_inches(11.7, 8.27)
        ylabel_prefix = "terms and "
    else:
        skill_scores = skill_scores.loc[:, ["CASE I", "CASE II", "CASE III", "CASE IV"], :]
        ylabel_prefix = ''

    long_format = skill_scores.to_dataframe("data").reset_index(level=[0, 1, 2])
    sns.boxplot(x="terms", y="data", hue="ahead", data=long_format, ax=ax, whis=1., palette="Blues_d", showmeans=True,
                meanprops={"markersize": 1, "markeredgecolor": "k"}, flierprops={"marker": "."})
    ax.axhline(y=0, color="grey", linewidth=.5)
    ax.set(ylabel=f"{ylabel_prefix}skill score", xlabel="", title="summary of all stations")
    legend_handles = ax.get_legend_handles_labels()[0]
    ax.legend(legend_handles, legend_labels)
    plt.tight_layout()

    target_file = os.path.join(plot_folder, f"skill_score_clim_{extra_name_tag}{model_setup}.pdf")
    logging.debug(f"... save plot to {target_file}")
    plt.savefig(target_file, dpi=500)
    plt.close('all')


def plot_competitive_skill_score(data: Dict[str, pd.DataFrame], plot_folder=".", model_setup="CNN"):
    """
    Create competitive skill score plot for the given model setup against the reference models.

    The references are ordinary least squares ("ols") and the persistence forecast ("persi"), shown for all lead
    times ("ahead"). The plot is saved under plot_folder with the name skill_score_competitive_{model_setup}.pdf and
    resolution of 500dpi.

    :param data: dictionary with station names as keys and data frames as values. Each data frame has
        index=['cnn-persi', 'ols-persi', 'cnn-ols'] and the lead times ("ahead") as columns, containing the
        pre-calculated comparisons for cnn, persistence and ols.
    :param plot_folder: path to save the plot (default: current directory)
    :param model_setup: architecture type (default "CNN")
    """
    logging.debug("run plot_competitive_skill_score()")

    # concatenating the mapping yields a (station, comparison) MultiIndex
    data = pd.concat(data, axis=0)
    data = xr.DataArray(data, dims=["stations", "ahead"]).unstack("stations")
    data = data.rename({"stations_level_0": "stations", "stations_level_1": "comparison"})
    data = data.to_dataframe("data").unstack(level=1).swaplevel()
    data.columns = data.columns.levels[1]

    labels = [str(i) + "d" for i in data.index.levels[1].values]
    # long format with columns "comparison", "ahead" and "data"
    data = data.stack(level=0).reset_index(level=2, drop=True).reset_index(name="data")

    # access the score column by name; positional Series indexing (data.min()[2]) is deprecated in pandas
    lower_limit = np.min([0, helpers.float_round(data["data"].min(), 2) - 0.1])
    upper_limit = helpers.float_round(data["data"].max(), 2) + 0.1

    fig, ax = plt.subplots()
    sns.boxplot(x="comparison", y="data", hue="ahead", data=data, whis=1., ax=ax, palette="Blues_d", showmeans=True,
                meanprops={"markersize": 3, "markeredgecolor": "k"}, flierprops={"marker": "."},
                order=["cnn-persi", "ols-persi", "cnn-ols"])
    ax.axhline(y=0, color="grey", linewidth=.5)
    ax.set(ylabel="skill score", xlabel="competing models", title="summary of all stations",
           ylim=(lower_limit, upper_limit))
    handles, _ = ax.get_legend_handles_labels()
    ax.legend(handles, labels)
    plt.tight_layout()
    plot_name = os.path.join(plot_folder, f"skill_score_competitive_{model_setup}.pdf")
    logging.debug(f"... save plot to {plot_name}")
    plt.savefig(plot_name, dpi=500)
    plt.close()