__author__ = "Lukas Leufen, Felix Kleinert"
__date__ = '2019-12-17'

import os
import logging
import math
import warnings

import numpy as np
import xarray as xr
import pandas as pd

import matplotlib
import seaborn as sns
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from matplotlib.backends.backend_pdf import PdfPages

# This module emits logging.debug() messages; raise the threshold of the 'matplotlib' logger so that its own
# DEBUG/INFO records are dropped and only WARNING and above come through.
logging.getLogger(name='matplotlib').setLevel(level=logging.WARNING)


def plot_monthly_summary(stations, data_path, name: str, window_lead_time, target_var, plot_folder="."):
    """Draw monthly box plots of observations and CNN forecasts pooled over all stations.

    For every station the file ``os.path.join(data_path, name % station)`` is read. Its "CNN" entries (one per
    ``ahead`` step, relabelled "<ahead>d") are concatenated with the "orig" entries taken at ``ahead=1``
    (labelled "orig"). The time stamps on the ``index`` dimension are replaced by their calendar month (1-12) and
    negative values are clipped to 0. All stations are pooled and drawn as one box per month and ``ahead`` label;
    the figure is saved as ``test_monthly_box.pdf`` in ``plot_folder``.

    :param stations: station identifiers, each substituted into ``name``
    :param data_path: directory containing the forecast files
    :param name: file name template with one ``%`` placeholder for the station
    :param window_lead_time: number of forecast steps; sets how many blue shades follow the green "orig" colour
    :param target_var: label of the y axis
    :param plot_folder: output directory of the pdf (default: current directory)
    :raises ValueError: if ``stations`` is empty
    """
    logging.debug("run plot_monthly_summary()")
    per_station = []

    for station in stations:
        logging.debug(f"... preprocess station {station}")
        file_name = os.path.join(data_path, name % station)
        # load into memory so the file handle can be released right away
        with xr.open_dataarray(file_name) as raw:
            data = raw.load()

        data_cnn = data.sel(type="CNN").squeeze()
        # Dimension coordinates must not be mutated through ``.values`` (recent xarray raises a ValueError);
        # assign_coords returns a relabelled copy and works on all versions.
        data_cnn = data_cnn.assign_coords(ahead=[f"{days}d" for days in data_cnn["ahead"].values])

        data_orig = data.sel(type="orig", ahead=1).squeeze().assign_coords(ahead="orig")

        data_concat = xr.concat([data_orig, data_cnn], dim="ahead").drop("type")

        # datetime64[M] counts months since 1970-01, so "% 12 + 1" gives the calendar month (1 = January)
        months = data_concat["index"].values.astype("datetime64[M]").astype(int) % 12 + 1
        data_concat = data_concat.assign_coords(index=months).clip(min=0)

        per_station.append(data_concat)

    if not per_station:
        raise ValueError("plot_monthly_summary() requires at least one station.")
    # a single concat avoids re-copying the growing array once per station
    forecasts = xr.concat(per_station, dim="index")

    forecasts = forecasts.to_dataset(name='values').to_dask_dataframe()
    logging.debug("... start plotting")
    # green for the observations ("orig"), then one blue shade per forecast step
    palette = [matplotlib.colors.cnames["green"]] + sns.color_palette("Blues_d", window_lead_time).as_hex()
    ax = sns.boxplot(x='index', y='values', hue='ahead', data=forecasts.compute(), whis=1.,
                     palette=palette,
                     flierprops={'marker': '.', 'markersize': 1},
                     showmeans=True, meanprops={'markersize': 1, 'markeredgecolor': 'k'})
    ax.set(xlabel='month', ylabel=f'{target_var}')
    plt.tight_layout()
    plot_path = os.path.join(os.path.abspath(plot_folder), 'test_monthly_box.pdf')
    logging.debug(f"... save plot to {plot_path}")
    plt.savefig(plot_path)
    plt.close('all')


def plot_climsum_boxplot():
    """Placeholder for a climatological summary box plot.

    Not implemented yet: takes no arguments, draws nothing and returns None.
    """
    return None


def station_map(generators, plot_folder="."):
    """Draw station locations on a map and save it as ``test_map_plot.pdf`` in ``plot_folder``.

    The map covers longitudes 0-20 and latitudes 42-58 (PlateCarree) with coastlines, lakes, ocean, rivers and
    borders. Each station is drawn as a square marker.

    :param generators: mapping of marker face colour -> data generator, or None for a map without stations. Every
        generator is enumerated; for position ``k`` the station coordinates are read from the rows "station_lon"
        and "station_lat" of ``gen.get_data_generator(k).meta``.
    :param plot_folder: output directory of the pdf (default: current directory)
    """
    logging.debug("run station_map()")
    fig = plt.figure(figsize=(10, 5))
    ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
    ax.set_extent([0, 20, 42, 58], crs=ccrs.PlateCarree())
    ax.add_feature(cfeature.COASTLINE.with_scale("10m"), edgecolor='black')
    ax.add_feature(cfeature.LAKES.with_scale("50m"))
    ax.add_feature(cfeature.OCEAN.with_scale("50m"))
    ax.add_feature(cfeature.RIVERS.with_scale("10m"))
    ax.add_feature(cfeature.BORDERS.with_scale("10m"), facecolor='none', edgecolor='black')

    if generators is not None:
        for color, gen in generators.items():
            # enumerate() only provides the station positions; the yielded items themselves are not needed
            for k, _ in enumerate(gen):
                # look the station generator up once instead of once per metadata field
                meta = gen.get_data_generator(k).meta
                # .item() extracts the single value; float() of a 1-element array is deprecated in NumPy >= 1.25
                lon = float(meta.loc['station_lon'].values.item())
                lat = float(meta.loc['station_lat'].values.item())
                ax.plot(lon, lat, mfc=color, mec='k', marker='s', markersize=6, transform=ccrs.PlateCarree())

    plot_path = os.path.join(os.path.abspath(plot_folder), 'test_map_plot.pdf')
    logging.debug(f"... save plot to {plot_path}")
    fig.savefig(plot_path)
    plt.close('all')


def plot_conditional_quantiles(stations, plot_folder=".", q=None, rolling_window=3, ref_name='orig', pred_name='CNN', season="", forecast_path=None):
    """Plot quantiles of ``ref_name`` conditioned on binned ``pred_name`` values, one pdf page per ``ahead`` step.

    For each station, ``forecasts_<station>_test.nc`` is read from ``forecast_path``. The entries "CNN", "orig" and
    "OLS" of its third dimension (transposed as ``type`` below) are kept. Time (``index``) and station are stacked
    into a single sample dimension. The ``pred_name`` values are binned into unit-width bins (0, 1], (1, 2], ...
    up to the rounded-up overall maximum, each labelled by its upper edge. For every ``ahead`` step and bin, the
    quantiles ``q`` of the matching ``ref_name`` values are computed. They are smoothed with a centred rolling mean
    over ``rolling_window`` bins and drawn with a 1:1 reference line, together with a histogram of the binned
    predictions (sample size) on a log-scaled twin axis.

    :param stations: station identifiers used to build the forecast file names
    :param plot_folder: output directory of the pdf (default: current directory)
    :param q: quantiles to compute (default: [.1, .25, .5, .75, .9]). Line styles and legend texts are written for
        these five quantiles, so other choices may be mislabelled.
    :param rolling_window: number of bins in the centred rolling mean applied to the quantile curves
    :param ref_name: ``type`` entry whose quantiles are plotted (y axis)
    :param pred_name: ``type`` entry that is binned (x axis)
    :param season: optional suffix for the file name and the page titles
    :param forecast_path: directory containing the forecast files (required)
    :raises ValueError: if ``forecast_path`` is None
    """
    # validate before touching any global state or files
    if forecast_path is None:
        raise ValueError("Forecast path is not given but required.")
    if q is None:
        q = [.1, .25, .5, .75, .9]

    plot_name = f"test_conditional_quantiles{f'_{season}' if len(season) > 0 else ''}"
    plot_path = os.path.join(os.path.abspath(plot_folder), f"{plot_name}_plot.pdf")

    # Scope the warning filters to this call; bare warnings.filterwarnings() would silence these messages for the
    # rest of the process.
    with warnings.catch_warnings():
        # ignore warnings if nans appear in quantile grouping
        warnings.filterwarnings("ignore", message="All-NaN slice encountered")
        # ignore warnings if mean is calculated on nans
        warnings.filterwarnings("ignore", message="Mean of empty slice")
        # ignore warnings for y tick = 0 on log scale (instead of 0.00001 or similar)
        warnings.filterwarnings("ignore", message="Attempted to set non-positive bottom ylim on a log-scaled axis.")

        data = []
        for station in stations:
            file_name = os.path.join(forecast_path, f"forecasts_{station}_test.nc")
            # load the selection into memory so the file can be closed immediately
            with xr.open_dataarray(file_name) as data_tmp:
                data.append(data_tmp.loc[:, :, ['CNN', 'orig', 'OLS']].assign_coords(station=station).load())
        data = xr.concat(data, dim='station').transpose('index', 'type', 'ahead', 'station')

        linetype = [':', '-.', '--', '-.', ':']
        # bin edges 0, 1, ..., ceil(max) over all types, steps and stations
        bins = np.arange(0, math.ceil(data.max()) + 1, 1).astype(int)
        categories = bins[1:]
        xlabel = 'forecast concentration (in ppb)'
        ylabel = 'observed concentration (in ppb)'

        # merge time and station into one sample dimension "z", renumbered 0..n-1
        data = data.stack(z=['index', 'station'])
        data.coords['z'] = range(len(data.coords['z']))
        # replace each prediction by the label (upper edge) of its bin; values outside all bins become NaN
        data.loc[pred_name, ...] = data.loc[pred_name, ...].to_pandas().T.apply(pd.cut, bins=bins, labels=categories).T.values

        quantile_panel = xr.DataArray(np.full([data.ahead.shape[0], len(q), categories.shape[0]], np.nan),
                                      coords=[data.ahead, q, categories], dims=['ahead', 'quantiles', 'categories'])
        ref_values = data.loc[ref_name, ...]
        pred_categories = data.loc[pred_name, ...]
        for category in categories:
            quantile_panel.loc[..., category] = ref_values.where(pred_categories == category).quantile(q, dim=['z']).T

        y2_max = 0
        with PdfPages(plot_path) as pdf_pages:
            for iteration, d in enumerate(data.ahead):
                logging.debug(f"plotting {d.values} time step(s) ahead")
                # DataFrame.plot() without ``ax`` draws into a new figure, giving one figure per page
                ax = quantile_panel.loc[d, ...].rolling(categories=rolling_window, center=True).mean().to_pandas().T.plot(style=linetype, color='black', legend=False)
                ax2 = ax.twinx()
                ax.plot([0, bins.max()], [0, bins.max()], color='k', label='reference 1:1', linewidth=.8)
                handles, _ = ax.get_legend_handles_labels()
                data.loc[pred_name, d, :].to_pandas().hist(bins=bins, ax=ax2, color='k', alpha=.3, grid=False, rwidth=1)
                plt.legend(handles[:3]+[handles[-1]], ('.10th and .90th quantiles', '.25th and .75th quantiles', '.50th quantile', 'reference 1:1'), loc='upper left', fontsize='large')
                ax.set(xlim=(0, bins.max()), ylim=(0, bins.max()))
                ax.set_xlabel(xlabel, fontsize='x-large')
                ax.tick_params(axis='x', which='major', labelsize=15)
                ax.set_ylabel(ylabel, fontsize='x-large')
                ax.tick_params(axis='y', which='major', labelsize=15)
                ax2.yaxis.label.set_color('gray')
                ax2.tick_params(axis='y', colors='gray')
                ax2.yaxis.labelpad = -15
                ax2.set_yscale('log')
                # the upper limit of the sample-size axis is fixed from the first page so all pages share it
                if iteration == 0:
                    y2_max = ax2.get_ylim()[1]+100
                ax2.set(ylim=(0, y2_max*10**8), yticks=np.logspace(0, 4, 5))
                ax2.set_ylabel('              sample size', fontsize='x-large')
                ax2.tick_params(axis='y', which='major', labelsize=15)
                title = f"{d.values} time step(s) ahead{f' ({season})' if len(season) > 0 else ''}"
                plt.title(title)
                pdf_pages.savefig()
                # the page is written; release its figure instead of keeping every page open until the end
                plt.close(ax.figure)
    plt.close('all')