__author__ = 'Felix Kleinert, Lukas Leufen'
__date__ = '2019-10-16'

import datetime as dt
from functools import reduce
import logging
import os
from typing import Union, List, Iterable, Tuple

import numpy as np
import pandas as pd
import xarray as xr

from src import join, helpers
from src import statistics

# define a more general date type for type hinting
date = Union[dt.date, dt.datetime]
str_or_list = Union[str, List[str]]


class DataPrep(object):
    """
    This class prepares data to be used in neural networks. The instance searches for locally stored data that meets
    the given demands. If no local data is found, the DataPrep instance loads data from the TOAR database and stores
    it locally for subsequent runs. For the moment, only daily aggregated time series are supported. The aggregation
    can be set manually and can differ for each variable.

    After data loading, different pre-processing steps can be executed to prepare the data for further applications.
    In particular, the following methods can be used for the pre-processing step:
    - interpolate: interpolate between data points by using xarray's interpolation method
    - standardise: standardise data to mean=0 and std=1, or centre data to mean=0; additional methods like normalise
      on the interval [0, 1] are not implemented yet.
    - make window history: represent the history (preceding time steps) for training/testing; X
    - make labels: create the target vector with the given lead time steps for training/testing; y
    - remove NaNs jointly from the desired inputs and outputs; only time steps where no NaNs are present in X AND y
      are kept. Use this method after the creation of the window history and labels to clean up the data cube.

    To create a DataPrep instance, the station id (e.g. "DEBW107"), its network (e.g. UBA, "Umweltbundesamt") and the
    variables to use must be specified. Further options can be set on the instance (see the usage sketch below):
    * `statistics_per_var`: define the statistic to extract from the TOAR database for each variable.
    * `start`: start date for the data cube creation. Default: use the first entry in the time series.
    * `end`: end date for the data cube. Default: use the last entry in the time series.
    * `store_data_locally`: store recently downloaded data on the local disk. Default: True.
    * further parameters for xarray's interpolation methods to modify the interpolation scheme.
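
    A minimal usage sketch (station, network, and statistics are example values, following this module's __main__
    block; 'data/' must be writable and the JOIN interface reachable if no local data exists):

        data_prep = DataPrep('data/', 'UBA', 'DEBW107', ['o3', 'temp'],
                             statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})
        data_prep.interpolate('datetime', limit=1)
        data_prep.make_history_window('variables', 7, 'datetime')
        data_prep.make_labels('variables', 'o3', 'datetime', 4)
        data_prep.make_observation('variables', 'o3', 'datetime')
        data_prep.remove_nan('datetime')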

    """

    def __init__(self, path: str, network: str, station: Union[str, List[str]], variables: List[str],
                 station_type: str = None, **kwargs):
        self.path = os.path.abspath(path)
        self.network = network
        self.station = helpers.to_list(station)
        self.variables = variables
        self.station_type = station_type
        self.mean = None
        self.std = None
        self.history = None
        self.label = None
        self.observation = None
        self.kwargs = kwargs
        self.data = None
        self.meta = None
        self._transform_method = None
        self.statistics_per_var = kwargs.get("statistics_per_var", None)
        self.sampling = kwargs.get("sampling", "daily")
        if self.statistics_per_var is not None or self.sampling == "hourly":
            self.load_data()
        else:
            raise NotImplementedError("Either select hourly data or provide statistics_per_var.")

    def load_data(self):
        """
        Load data and meta data either from local disk (preferred) or download new data from the TOAR database if no
        local data is available. In the latter case, the downloaded data is stored locally if wished (default: True).
        """
        helpers.check_path_and_create(self.path)
        file_name = self._set_file_name()
        meta_file = self._set_meta_file_name()
        if self.kwargs.get('overwrite_local_data', False):
            logging.debug(f"overwrite_local_data is true, therefore reload {file_name} from JOIN")
            if os.path.exists(file_name):
                os.remove(file_name)
            if os.path.exists(meta_file):
                os.remove(meta_file)
            self.download_data(file_name, meta_file)
            logging.debug("loaded new data from JOIN")
        else:
            try:
                logging.debug(f"try to load local data from: {file_name}")
                data = self._slice_prep(xr.open_dataarray(file_name))
                self.data = self.check_for_negative_concentrations(data)
                self.meta = pd.read_csv(meta_file, index_col=0)
                if self.station_type is not None:
                    self.check_station_meta()
                logging.debug("loading finished")
            except FileNotFoundError as e:
                logging.warning(e)
                self.download_data(file_name, meta_file)
                logging.debug("loaded new data from JOIN")

    def download_data(self, file_name, meta_file):
        data, self.meta = self.download_data_from_join(file_name, meta_file)
        data = self._slice_prep(data)
        self.data = self.check_for_negative_concentrations(data)

    def check_station_meta(self):
        """
        Compare entries in the local meta data with the requested values and raise a FileNotFoundError if they
        mismatch. This triggers a fresh download of the data.
        """
        check_dict = {"station_type": self.station_type, "network_name": self.network}
        for (k, v) in check_dict.items():
            if self.meta.at[k, self.station[0]] != v:
                logging.debug(f"meta data does not agree with given request for {k}: {v} (requested) != "
                              f"{self.meta.at[k, self.station[0]]} (local). Raise FileNotFoundError to trigger new "
                              f"grapping from web.")
                raise FileNotFoundError

    def download_data_from_join(self, file_name: str, meta_file: str) -> Tuple[xr.DataArray, pd.DataFrame]:
        """
        Download data from the TOAR database using the JOIN interface.
        :param file_name: path of the local netCDF file the data is stored in (if `store_data_locally` is True)
        :param meta_file: path of the local csv file the meta data is stored in (if `store_data_locally` is True)
        :return: downloaded data and its meta data
        """
        df_all = {}
        df, meta = join.download_join(station_name=self.station, stat_var=self.statistics_per_var,
                                      station_type=self.station_type, network_name=self.network, sampling=self.sampling)
        df_all[self.station[0]] = df
        # convert df_all to xarray
        xarr = {k: xr.DataArray(v, dims=['datetime', 'variables']) for k, v in df_all.items()}
        xarr = xr.Dataset(xarr).to_array(dim='Stations')
        if self.kwargs.get('store_data_locally', True):
            # save locally as nc/csv file
            xarr.to_netcdf(path=file_name)
            meta.to_csv(meta_file)
        return xarr, meta

    def _set_file_name(self):
        all_vars = sorted(self.statistics_per_var.keys())
        return os.path.join(self.path, f"{''.join(self.station)}_{'_'.join(all_vars)}.nc")

    def _set_meta_file_name(self):
        all_vars = sorted(self.statistics_per_var.keys())
        return os.path.join(self.path, f"{''.join(self.station)}_{'_'.join(all_vars)}_meta.csv")

    def __repr__(self):
        return f"Dataprep(path='{self.path}', network='{self.network}', station={self.station}, " \
               f"variables={self.variables}, station_type={self.station_type}, **{self.kwargs})"

    def interpolate(self, dim: str, method: str = 'linear', limit: int = None, use_coordinate: Union[bool, str] = True,
                    **kwargs):
        """
        (Docstring copied from xarray's DataArray.interpolate_na.)
        Interpolate values according to different methods.

        :param dim:
                Specifies the dimension along which to interpolate.
        :param method:
                {'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic',
                          'polynomial', 'barycentric', 'krog', 'pchip',
                          'spline', 'akima'}, optional
                    String indicating which method to use for interpolation:

                    - 'linear': linear interpolation (Default). Additional keyword
                      arguments are passed to ``numpy.interp``
                    - 'nearest', 'zero', 'slinear', 'quadratic', 'cubic',
                      'polynomial': are passed to ``scipy.interpolate.interp1d``. If
                      method=='polynomial', the ``order`` keyword argument must also be
                      provided.
                    - 'barycentric', 'krog', 'pchip', 'spline', and 'akima': use their
                      respective ``scipy.interpolate`` classes.
        :param limit:
                    default None
                    Maximum number of consecutive NaNs to fill. Must be greater than 0
                    or None for no limit.
        :param use_coordinate:
                default True
                    Specifies which index to use as the x values in the interpolation
                    formulated as `y = f(x)`. If False, values are treated as if
                    equally-spaced along `dim`. If True, the IndexVariable `dim` is
                    used. If use_coordinate is a string, it specifies the name of a
                    coordinate variable to use as the index.
        :param kwargs: further keyword arguments, passed on to the interpolation method
        :return: None. The interpolated data is stored in self.data.
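
        Example (a sketch; `data_prep` is a loaded DataPrep instance and 'datetime' is the time dimension used
        throughout this module):

            data_prep.interpolate('datetime', method='linear', limit=2)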
        """

        self.data = self.data.interpolate_na(dim=dim, method=method, limit=limit, use_coordinate=use_coordinate,
                                             **kwargs)

    @staticmethod
    def check_inverse_transform_params(mean, std, method) -> None:
        msg = ""
        if method in ['standardise', 'centre'] and mean is None:
            msg += "mean, "
        if method == 'standardise' and std is None:
            msg += "std, "
        if len(msg) > 0:
            raise AttributeError(f"Inverse transform {method} can not be executed because following is None: {msg}")

    def inverse_transform(self) -> None:
        """
        Revert a previously applied transformation. The data is transformed back, and the internal mean, std, and
        transformation method are reset to None afterwards.
        """

        def f_inverse(data, mean, std, method_inverse):
            if method_inverse == 'standardise':
                return statistics.standardise_inverse(data, mean, std), None, None
            elif method_inverse == 'centre':
                return statistics.centre_inverse(data, mean), None, None
            elif method_inverse == 'normalise':
                raise NotImplementedError
            else:
                raise NotImplementedError

        if self._transform_method is None:
            raise AssertionError("Inverse transformation method is not set. Data cannot be inverse transformed.")
        self.check_inverse_transform_params(self.mean, self.std, self._transform_method)
        self.data, self.mean, self.std = f_inverse(self.data, self.mean, self.std, self._transform_method)
        self._transform_method = None

    def transform(self, dim: Union[str, int] = 0, method: str = 'standardise', inverse: bool = False, mean=None,
                  std=None) -> None:
        """
        This function transforms a xarray.DataArray (along dim) or pandas.DataFrame (along axis) either to mean=0
        and std=1 (`method='standardise'`) or centres the data to mean=0 without changing the data scale
        (`method='centre'`). Furthermore, it sets an internal instance attribute for a later inverse transformation.
        This method raises an AssertionError if a transform method was already set (`inverse=False`) or if no
        transform method is set (`inverse=True`); a missing mean or std for the inverse raises an AttributeError.

        :param dim: This param is not used for inverse transformation.
                | for xarray.DataArray as string: name of dimension which should be standardised
                | for pandas.DataFrame as int: axis of dimension which should be standardised
        :param method: Choose the transformation method from 'standardise' and 'centre'. 'normalise' is not
                    implemented yet. This param is not used for inverse transformation.
        :param inverse: Switch between transformation and inverse transformation.
        :param mean: Externally given mean. If set, the transformation is applied with this mean instead of computing
                    it from the data.
        :param std: Externally given standard deviation, used together with an external mean for 'standardise'.
        :return: None. The results are stored in the instance attributes:
                #. self.mean: mean of data
                #. self.std: standard deviation of data
                #. self.data: transformed data
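
        Example (a sketch; `data_prep` is a loaded DataPrep instance):

            data_prep.transform('datetime')    # standardise along datetime; stores mean and std internally
            data_prep.transform(inverse=True)  # revert the data to its original scale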
        """

        def f(data):
            if method == 'standardise':
                return statistics.standardise(data, dim)
            elif method == 'centre':
                return statistics.centre(data, dim)
            elif method == 'normalise':
                # use min/max of data or given min/max
                raise NotImplementedError
            else:
                raise NotImplementedError

        def f_apply(data):
            if method == "standardise":
                return mean, std, statistics.standardise_apply(data, mean, std)
            elif method == "centre":
                return mean, None, statistics.centre_apply(data, mean)
            else:
                raise NotImplementedError

        if not inverse:
            if self._transform_method is not None:
                raise AssertionError(f"Transform method is already set. Therefore, data was already transformed with "
                                     f"{self._transform_method}. Please perform inverse transformation of data first.")
            self.mean, self.std, self.data = locals()["f" if mean is None else "f_apply"](self.data)
            self._transform_method = method
        else:
            self.inverse_transform()

    def get_transformation_information(self, variable):
        try:
            mean = self.mean.sel({'variables': variable}).values
        except AttributeError:
            mean = None
        try:
            std = self.std.sel({'variables': variable}).values
        except AttributeError:
            std = None
        return mean, std, self._transform_method

    def make_history_window(self, dim_name_of_inputs: str, window: int, dim_name_of_shift: str) -> None:
        """
        This function shifts the data window+1 times and stores an xarray with a new dimension 'window' containing
        the shifted data in self.history. This is used to represent the history in the data.

        :param dim_name_of_inputs: Name of dimension which contains the input variables
        :param window: number of time steps to look back in history
                Note: window is treated as a negative value. This is in agreement with looking back on a time line.
                Nonetheless, positive values are allowed but are converted to their negative expression.
        :param dim_name_of_shift: Dimension along shift will be applied
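
        Example (a sketch): `data_prep.make_history_window('variables', 7, 'datetime')` stores the current and the 7
        preceding time steps in self.history, indexed -7, ..., 0 along the new 'window' dimension.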
        """
        window = -abs(window)
        self.history = self.shift(dim_name_of_shift, window).sel({dim_name_of_inputs: self.variables})

    def shift(self, dim: str, window: int) -> xr.DataArray:
        """
        This function uses xarray's shift function multiple times to represent history (if window <= 0)
        or lead time (if window > 0).
        :param dim: dimension along which the shift is applied
        :param window: number of steps to shift (corresponds to the window length)
        :return: shifted data as xarray.DataArray with an additional dimension 'window'
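
        Example (a sketch): `self.shift('datetime', -7)` returns the data with a new 'window' dimension indexed
        -7, ..., 0, where entry w holds the value observed at time t+w.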
        """
        start = 1
        end = 1
        if window <= 0:
            start = window
        else:
            end = window + 1
        res = []
        for w in range(start, end):
            res.append(self.data.shift({dim: -w}))
        window_array = self.create_index_array('window', range(start, end))
        res = xr.concat(res, dim=window_array)
        return res

    def make_labels(self, dim_name_of_target: str, target_var: str_or_list, dim_name_of_shift: str, window: int) -> None:
        """
        This function creates an xarray.DataArray containing the labels and stores it in self.label.

        :param dim_name_of_target: Name of dimension which contains the target variable
        :param target_var: Name of target variable in 'dimension'
        :param dim_name_of_shift: Name of dimension on which xarray.DataArray.shift will be applied
        :param window: lead time of label
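
        Example (a sketch): `data_prep.make_labels('variables', 'o3', 'datetime', 4)` stores the o3 values at lead
        times +1 to +4 in self.label.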
        """
        window = abs(window)
        self.label = self.shift(dim_name_of_shift, window).sel({dim_name_of_target: target_var})

    def make_observation(self, dim_name_of_target: str, target_var: str_or_list, dim_name_of_shift: str) -> None:
        """
        This function creates an xarray.DataArray containing the unshifted observation data and stores it in
        self.observation.

        :param dim_name_of_target: Name of dimension which contains the target variable
        :param target_var: Name of target variable(s) in 'dimension'
        :param dim_name_of_shift: Name of dimension on which xarray.DataArray.shift will be applied
        """
        self.observation = self.shift(dim_name_of_shift, 0).sel({dim_name_of_target: target_var})

    def remove_nan(self, dim: str) -> None:
        """
        All slices along dim which contain NaNs in self.history, self.label, or self.observation are removed from all
        three data sets. This is done to present only a complete matrix to keras.fit. If fewer than
        max(min_length, 1) time steps remain (min_length is read from kwargs, default 0), history, label, and
        observation are set to None instead.

        :param dim: dimension along which NaN-containing slices are dropped
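
        Example (a sketch; call after make_history_window, make_labels, and make_observation):

            data_prep.remove_nan('datetime')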
        """
        intersect = []
        if (self.history is not None) and (self.label is not None):
            non_nan_history = self.history.dropna(dim=dim)
            non_nan_label = self.label.dropna(dim=dim)
            non_nan_observation = self.observation.dropna(dim=dim)
            intersect = reduce(np.intersect1d, (non_nan_history.coords[dim].values, non_nan_label.coords[dim].values,
                                                non_nan_observation.coords[dim].values))

        min_length = self.kwargs.get("min_length", 0)
        if len(intersect) < max(min_length, 1):
            self.history = None
            self.label = None
            self.observation = None
        else:
            self.history = self.history.sel({dim: intersect})
            self.label = self.label.sel({dim: intersect})
            self.observation = self.observation.sel({dim: intersect})

    @staticmethod
    def create_index_array(index_name: str, index_value: Iterable[int]) -> xr.DataArray:
        """
        This function creates a 1D xarray.DataArray with the given index name and values.

        :param index_name: name of the new index dimension
        :param index_value: values the new index takes
        :return: 1D xarray.DataArray named index_name
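
        Example (a sketch): `create_index_array('window', range(-3, 1))` returns a 1D DataArray with the values
        [-3, -2, -1, 0] along a dimension named 'window'.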
        """
        ind = pd.DataFrame({'val': index_value}, index=index_value)
        res = xr.Dataset.from_dataframe(ind).to_array().rename({'index': index_name}).squeeze(dim='variable', drop=True)
        res.name = index_name
        return res

    def _slice_prep(self, data: xr.DataArray, coord: str = 'datetime') -> xr.DataArray:
        """
        This function prepares all settings for slicing and executes _slice.
        :param data: data to slice
        :param coord: name of the axis to slice along
        :return: sliced data
        """
        start = self.kwargs.get('start', data.coords[coord][0].values)
        end = self.kwargs.get('end', data.coords[coord][-1].values)
        return self._slice(data, start, end, coord)

    @staticmethod
    def _slice(data: xr.DataArray, start: Union[date, str], end: Union[date, str], coord: str) -> xr.DataArray:
        """
        This function slices a given data array along the given coordinate (for example, select only values of 2011).
        :param data: data to slice
        :param start: first date to keep
        :param end: last date to keep (inclusive)
        :param coord: name of the axis to slice along
        :return: sliced data
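
        Example (a sketch): `DataPrep._slice(data, "2011-01-01", "2011-12-31", "datetime")` keeps exactly the year
        2011; both bounds are inclusive in label-based slicing.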
        """
        return data.loc[{coord: slice(str(start), str(end))}]

    def check_for_negative_concentrations(self, data: xr.DataArray, minimum: int = 0) -> xr.DataArray:
        """
        This function clips all negative concentrations to the given minimum (0 by default). The names of all
        concentration variables are taken from https://join.fz-juelich.de/services/rest/surfacedata/ #2.1 Parameters.
        :param data: data to check for negative concentrations
        :param minimum: smallest value a concentration may take
        :return: data with all negative concentrations clipped to minimum
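
        Example (a sketch): if `self.variables` contains 'o3', all negative o3 values in `data` are clipped to
        `minimum` (0 by default); non-chemical variables such as 'temp' are left untouched.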
        """
        chem_vars = ["benzene", "ch4", "co", "ethane", "no", "no2", "nox", "o3", "ox", "pm1", "pm10", "pm2p5",
                     "propane", "so2", "toluene"]
        used_chem_vars = list(set(chem_vars) & set(self.variables))
        data.loc[..., used_chem_vars] = data.loc[..., used_chem_vars].clip(min=minimum)
        return data

    def get_transposed_history(self):
        return self.history.transpose("datetime", "window", "Stations", "variables").copy()

    def get_transposed_label(self):
        return self.label.squeeze("Stations").transpose("datetime", "window").copy()


if __name__ == "__main__":
    dp = DataPrep('data/', 'dummy', 'DEBW107', ['o3', 'temp'], statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})
    print(dp)