__author__ = "Lukas Leufen, Felix Kleinert"
__date__ = '2019-11-15'

import argparse
import logging
import os
from typing import Union, Dict, Any, List

from src.configuration import path_config
from src import helpers
from src.run_modules.run_environment import RunEnvironment

# default station subset used when no explicit station list is given
DEFAULT_STATIONS = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBY052', 'DEBY032', 'DEBW022',
                    'DEBY004', 'DEBY020', 'DEBW030', 'DEBW037', 'DEBW031', 'DEBW015', 'DEBW073', 'DEBY039',
                    'DEBW038', 'DEBW081', 'DEBY075', 'DEBW040', 'DEBY053', 'DEBW059', 'DEBW027', 'DEBY072',
                    'DEBW042', 'DEBW039', 'DEBY001', 'DEBY113', 'DEBY089', 'DEBW024', 'DEBW004', 'DEBY037',
                    'DEBW056', 'DEBW029', 'DEBY068', 'DEBW010', 'DEBW026', 'DEBY002', 'DEBY079', 'DEBW084',
                    'DEBY049', 'DEBY031', 'DEBW019', 'DEBW001', 'DEBY063', 'DEBY005', 'DEBW046', 'DEBW103',
                    'DEBW052', 'DEBW034', 'DEBY088', ]
# default statistics (JOIN aggregation method) per variable
DEFAULT_VAR_ALL_DICT = {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum', 'u': 'average_values',
                        'v': 'average_values', 'no': 'dma8eu', 'no2': 'dma8eu', 'cloudcover': 'average_values',
                        'pblheight': 'maximum'}
DEFAULT_TRANSFORMATION = {"scope": "data", "method": "standardise", "mean": "estimate"}
DEFAULT_PLOT_LIST = ["PlotMonthlySummary", "PlotStationMap", "PlotClimatologicalSkillScore", "PlotTimeSeries",
                     "PlotCompetitiveSkillScore", "PlotBootstrapSkillScore", "PlotConditionalQuantiles",
                     "PlotAvailability"]
DEFAULT_HPC_LOGIN_LIST = ["ju", "hdfmll"]  # name prefixes of login nodes: ju[wels], hdfmll(ogin)
DEFAULT_HPC_HOST_LIST = ["jw", "hdfmlc"]  # first part of compute node names: jw[comp], hdfmlc(ompute)


class ExperimentSetup(RunEnvironment):
    """
    Set up the model.

    Schedule of experiment setup:
        * set up experiment path
        * set up data path (according to host system)
        * set up forecast, bootstrap and plot path (inside experiment path)
        * set all parameters given in args (or use default values)
        * check target variable
        * check `variables` and `statistics_per_var` parameter for consistency

    Sets
        * `data_path` [.]
        * `create_new_model` [.]
        * `bootstrap_path` [.]
        * `trainable` [.]
        * `fraction_of_training` [.]
        * `extreme_values` [train]
        * `extremes_on_right_tail_only` [train]
        * `upsampling` [train]
        * `permute_data` [train]
        * `experiment_name` [.]
        * `experiment_path` [.]
        * `plot_path` [.]
        * `forecast_path` [.]
        * `stations` [.]
        * `network` [.]
        * `station_type` [.]
        * `statistics_per_var` [.]
        * `variables` [.]
        * `start` [.]
        * `end` [.]
        * `window_history_size` [.]
        * `overwrite_local_data` [preprocessing]
        * `sampling` [.]
        * `transformation` [., preprocessing]
        * `target_var` [.]
        * `target_dim` [.]
        * `window_lead_time` [.]

    Creates
        * plot of model architecture in `<model_name>.pdf`

    :param parser_args: argument parser, currently only accepting ``experiment_date argument`` to be used for
        experiment's name and path creation. Final experiment's name is derived from given name and the time series
        sampling as `<name>_network_<sampling>/` . All interim and final results, logging, plots, ... of this run are
        stored in this directory if not explicitly provided in kwargs. Only the data itself and data for bootstrap
        investigations are stored outside this structure.
    :param stations: list of stations or single station to use in experiment. If not provided, stations are set to
        :py:const:`default stations <DEFAULT_STATIONS>`.
    :param network: name of network to restrict to use only stations from this measurement network. Default is
        `AIRBASE` .
    :param station_type: restrict network type to one of TOAR's categories (background, traffic, industrial). Default
        is `None` to use all categories.
    :param variables: list of all variables to use. Valid names can be found in
        `Section 2.1 Parameters <https://join.fz-juelich.de/services/rest/surfacedata/>`_. If not provided, this
        parameter is filled with keys from ``statistics_per_var``.
    :param statistics_per_var: dictionary with statistics to use for variables (if data is daily and loaded from JOIN).
        If not provided, :py:const:`default statistics <DEFAULT_VAR_ALL_DICT>` is applied. ``statistics_per_var`` is
        compared with given ``variables`` and unused variables are removed. Therefore, statistics at least need to
        provide all variables from ``variables``. For more details on available statistics, we refer to
        `Section 3.3 List of statistics/metrics for stats service
        <https://join.fz-juelich.de/services/rest/surfacedata/>`_ in the JOIN documentation. Valid parameter names can
        be found in `Section 2.1 Parameters <https://join.fz-juelich.de/services/rest/surfacedata/>`_.
    :param start: start date of overall data (default `"1997-01-01"`)
    :param end: end date of overall data (default `"2017-12-31"`)
    :param window_history_size: number of time steps to use for input data (default 13). Time steps `t_0 - w` to `t_0`
        are used as input data (therefore actual data size is `w+1`).
    :param target_var: target variable to predict by model, currently only a single target variable is supported.
        Because this framework was originally designed to predict ozone, default is `"o3"`.
    :param target_dim: dimension of target variable (default `"variables"`).
    :param window_lead_time: number of time steps to predict by model (default 3). Time steps `t_0+1` to `t_0+w` are
        predicted.
    :param dimensions:
    :param interpolate_dim:
    :param interpolate_method:
    :param limit_nan_fill:
    :param train_start:
    :param train_end:
    :param val_start:
    :param val_end:
    :param test_start:
    :param test_end:
    :param use_all_stations_on_all_data_sets:
    :param trainable: train a new model from scratch or resume training with existing model if `True` (default) or
        freeze loaded model and do not perform any modification on it. ``trainable`` is set to `True` if
        ``create_new_model`` is `True`.
    :param fraction_of_train: given value is used to split between test data and train data (including validation
        data). The value of ``fraction_of_train`` must be in `(0, 1)` but is recommended to be in the interval
        `[0.6, 0.9]`. Default value is `0.8`. Split between train and validation is fixed to 80% - 20% and currently
        not changeable.
    :param experiment_path:
    :param plot_path: path to save all plots. If left blank, this will be included in the experiment path
        (recommended). Otherwise customise the location to save all plots.
    :param forecast_path: path to save all forecasts in files. It is recommended to leave this parameter blank, all
        forecasts will be the directory `forecasts` inside the experiment path (default). For customisation, add your
        path here.
    :param overwrite_local_data: Reload input and target data from web and replace local data if `True` (default
        `False`).
    :param sampling: set temporal sampling rate of data. You can choose from daily (default), monthly, seasonal,
        vegseason, summer and annual for aggregated values and hourly for the actual values. Note, that hourly values
        on JOIN are currently not accessible from outside. To access this data, you need to add your personal token in
        :py:mod:`join settings <src.configuration.join_settings>` and make sure to untrack this file!
    :param create_new_model: determine whether a new model will be created (`True`, default) or not (`False`). If this
        parameter is set to `False`, make sure, that a suitable model already exists in the experiment path. This model
        must fit in terms of input and output dimensions as well as ``window_history_size`` and ``window_lead_time``
        and must be implemented as a :py:mod:`model class <src.model_modules.model_class>` and imported in
        :py:mod:`model setup <src.run_modules.model_setup>`. If ``create_new_model`` is `True`, parameter ``trainable``
        is automatically set to `True` too.
    :param bootstrap_path:
    :param permute_data_on_training: shuffle train data individually for each station if `True`. This is performed each
        iteration for new, so that each sample very likely differs from epoch to epoch. Train data permutation is
        disabled (`False`) per default. In the case of extreme value manifolding, data permutation is enabled anyway.
    :param transformation: set transformation options in dictionary style. All information about transformation options
        can be found in :py:meth:`setup transformation
        <src.data_handling.data_generator.DataGenerator.setup_transformation>`. If no transformation is provided, all
        options are set to :py:const:`default transformation <DEFAULT_TRANSFORMATION>`.
    :param train_min_length:
    :param val_min_length:
    :param test_min_length:
    :param extreme_values: augment target samples with values of lower occurrences indicated by its normalised
        deviation from mean by manifolding. These extreme values need to be indicated by a list of thresholds. For each
        entry in this list, all values outside an +/- interval will be added in the training (and only the training)
        set for a second time to the sample. If multiple values are given, a sample is added for each exceedance once.
        E.g. a sample with `value=2.5` occurs twice in the training set for given `extreme_values=[2, 3]`, whereas a
        sample with `value=5` occurs three times in the training set. For default, upsampling of extreme values is
        disabled (`None`). Upsampling can be modified to manifold only values that are actually larger than given
        values from ``extreme_values`` (apply only on right side of distribution) by using
        ``extremes_on_right_tail_only``. This can be useful for positive skew variables.
    :param extremes_on_right_tail_only: applies only if ``extreme_values`` are given. If
        ``extremes_on_right_tail_only`` is `True`, only manifold values that are larger than given extremes (apply
        upsampling only on right side of distribution). In default mode, this is set to `False` to manifold extremes on
        both sides.
    :param evaluate_bootstraps:
    :param plot_list:
    :param number_of_bootstraps:
    :param create_new_bootstraps:
    :param data_path: path to find and store meteorological and environmental / air quality data. Leave this parameter
        empty, if your host system is known and a suitable path was already hardcoded in the program (see
        :py:func:`prepare host <src.configuration.path_config.prepare_host>`).
    :param login_nodes: name prefixes of HPC login nodes (default :py:const:`DEFAULT_HPC_LOGIN_LIST`).
    :param hpc_hosts: name prefixes of all HPC hosts, login and compute (default login plus
        :py:const:`DEFAULT_HPC_HOST_LIST`).
    """

    def __init__(self,
                 parser_args=None,
                 stations: Union[str, List[str]] = None,
                 network: str = None,
                 station_type: str = None,
                 variables: Union[str, List[str]] = None,
                 statistics_per_var: Dict = None,
                 start: str = None,
                 end: str = None,
                 window_history_size: int = None,
                 target_var="o3",
                 target_dim=None,
                 window_lead_time: int = None,
                 dimensions=None,
                 interpolate_dim=None,
                 interpolate_method=None,
                 limit_nan_fill=None,
                 train_start=None,
                 train_end=None,
                 val_start=None,
                 val_end=None,
                 test_start=None,
                 test_end=None,
                 use_all_stations_on_all_data_sets=True,
                 trainable: bool = None,
                 fraction_of_train: float = None,
                 experiment_path=None,
                 plot_path: str = None,
                 forecast_path: str = None,
                 overwrite_local_data: bool = None,
                 sampling: str = "daily",
                 create_new_model: bool = None,
                 bootstrap_path=None,
                 permute_data_on_training: bool = None,
                 transformation=None,
                 train_min_length=None,
                 val_min_length=None,
                 test_min_length=None,
                 extreme_values: list = None,
                 extremes_on_right_tail_only: bool = None,
                 evaluate_bootstraps=True,
                 plot_list=None,
                 number_of_bootstraps=None,
                 create_new_bootstraps=None,
                 data_path: str = None,
                 login_nodes=None,
                 hpc_hosts=None):
        # create run framework
        super().__init__()

        # experiment setup
        self._set_param("data_path", path_config.prepare_host(data_path=data_path, sampling=sampling))
        self._set_param("hostname", path_config.get_host())
        self._set_param("hpc_hosts", hpc_hosts, default=DEFAULT_HPC_HOST_LIST + DEFAULT_HPC_LOGIN_LIST)
        self._set_param("login_nodes", login_nodes, default=DEFAULT_HPC_LOGIN_LIST)
        self._set_param("create_new_model", create_new_model, default=True)
        # a freshly created model must also be trained
        if self.data_store.get("create_new_model"):
            trainable = True
        data_path = self.data_store.get("data_path")
        bootstrap_path = path_config.set_bootstrap_path(bootstrap_path, data_path, sampling)
        self._set_param("bootstrap_path", bootstrap_path)
        self._set_param("trainable", trainable, default=True)
        self._set_param("fraction_of_training", fraction_of_train, default=0.8)
        self._set_param("extreme_values", extreme_values, default=None, scope="train")
        self._set_param("extremes_on_right_tail_only", extremes_on_right_tail_only, default=False, scope="train")
        self._set_param("upsampling", extreme_values is not None, scope="train")
        upsampling = self.data_store.get("upsampling", "train")
        permute_data = False if permute_data_on_training is None else permute_data_on_training
        # upsampling of extremes forces data permutation during training
        self._set_param("permute_data", permute_data or upsampling, scope="train")

        # set experiment name
        exp_date = self._get_parser_args(parser_args).get("experiment_date")
        exp_name, exp_path = path_config.set_experiment_name(experiment_name=exp_date, experiment_path=experiment_path,
                                                             sampling=sampling)
        self._set_param("experiment_name", exp_name)
        self._set_param("experiment_path", exp_path)
        path_config.check_path_and_create(self.data_store.get("experiment_path"))

        # set model path
        self._set_param("model_path", None, os.path.join(exp_path, "model"))
        path_config.check_path_and_create(self.data_store.get("model_path"))

        # set plot path
        default_plot_path = os.path.join(exp_path, "plots")
        self._set_param("plot_path", plot_path, default=default_plot_path)
        path_config.check_path_and_create(self.data_store.get("plot_path"))

        # set results path
        default_forecast_path = os.path.join(exp_path, "forecasts")
        self._set_param("forecast_path", forecast_path, default_forecast_path)
        path_config.check_path_and_create(self.data_store.get("forecast_path"))

        # set logging path
        self._set_param("logging_path", None, os.path.join(exp_path, "logging"))
        path_config.check_path_and_create(self.data_store.get("logging_path"))

        # setup for data
        self._set_param("stations", stations, default=DEFAULT_STATIONS)
        self._set_param("network", network, default="AIRBASE")
        self._set_param("station_type", station_type, default=None)
        self._set_param("statistics_per_var", statistics_per_var, default=DEFAULT_VAR_ALL_DICT)
        self._set_param("variables", variables, default=list(self.data_store.get("statistics_per_var").keys()))
        self._set_param("start", start, default="1997-01-01")
        self._set_param("end", end, default="2017-12-31")
        self._set_param("window_history_size", window_history_size, default=13)
        self._set_param("overwrite_local_data", overwrite_local_data, default=False, scope="preprocessing")
        self._set_param("sampling", sampling)
        self._set_param("transformation", transformation, default=DEFAULT_TRANSFORMATION)
        self._set_param("transformation", None, scope="preprocessing")

        # target
        self._set_param("target_var", target_var, default="o3")
        self._set_param("target_dim", target_dim, default='variables')
        self._set_param("window_lead_time", window_lead_time, default=3)

        # interpolation
        self._set_param("dimensions", dimensions, default={'new_index': ['datetime', 'Stations']})
        self._set_param("interpolate_dim", interpolate_dim, default='datetime')
        self._set_param("interpolate_method", interpolate_method, default='linear')
        self._set_param("limit_nan_fill", limit_nan_fill, default=1)

        # train set parameters
        self._set_param("start", train_start, default="1997-01-01", scope="train")
        self._set_param("end", train_end, default="2007-12-31", scope="train")
        self._set_param("min_length", train_min_length, default=90, scope="train")

        # validation set parameters
        self._set_param("start", val_start, default="2008-01-01", scope="val")
        self._set_param("end", val_end, default="2009-12-31", scope="val")
        self._set_param("min_length", val_min_length, default=90, scope="val")

        # test set parameters
        self._set_param("start", test_start, default="2010-01-01", scope="test")
        self._set_param("end", test_end, default="2017-12-31", scope="test")
        self._set_param("min_length", test_min_length, default=90, scope="test")

        # train_val set parameters (spans train start to val end)
        self._set_param("start", self.data_store.get("start", "train"), scope="train_val")
        self._set_param("end", self.data_store.get("end", "val"), scope="train_val")
        train_val_min_length = sum([self.data_store.get("min_length", s) for s in ["train", "val"]])
        self._set_param("min_length", train_val_min_length, default=180, scope="train_val")

        # use all stations on all data sets (train, val, test)
        self._set_param("use_all_stations_on_all_data_sets", use_all_stations_on_all_data_sets, default=True)

        # set post-processing instructions; new bootstraps are forced whenever a model is (re-)trained
        self._set_param("evaluate_bootstraps", evaluate_bootstraps, scope="general.postprocessing")
        create_new_bootstraps = max([self.data_store.get("trainable", "general"), create_new_bootstraps or False])
        self._set_param("create_new_bootstraps", create_new_bootstraps, scope="general.postprocessing")
        self._set_param("number_of_bootstraps", number_of_bootstraps, default=20, scope="general.postprocessing")
        self._set_param("plot_list", plot_list, default=DEFAULT_PLOT_LIST, scope="general.postprocessing")

        # check variables, statistics and target variable
        self._check_target_var()
        self._compare_variables_and_statistics()

    def _set_param(self, param: str, value: Any, default: Any = None, scope: str = "general") -> None:
        """Set given parameter (falling back to ``default`` if ``value`` is None) and log in debug."""
        if value is None and default is not None:
            value = default
        self.data_store.set(param, value, scope)
        logging.debug(f"set experiment attribute: {param}({scope})={value}")

    @staticmethod
    def _get_parser_args(args: Union[Dict, argparse.Namespace]) -> Dict:
        """
        Transform args to dict if given as argparse.Namespace.

        :param args: either a dictionary or an argument parser instance

        :return: dictionary with all arguments
        """
        if isinstance(args, argparse.Namespace):
            return args.__dict__
        elif isinstance(args, dict):
            return args
        else:
            return {}

    def _compare_variables_and_statistics(self):
        """
        Compare variables and statistics.

        * raise error, if a variable is missing.
        * remove unused variables from statistics.
        """
        logging.debug("check if all variables are included in statistics_per_var")
        stat = self.data_store.get("statistics_per_var")
        var = self.data_store.get("variables")
        # too few entries, raise error
        if not set(var).issubset(stat.keys()):
            missing = set(var).difference(stat.keys())
            raise ValueError(f"Comparison of given variables and statistics_per_var show that not all requested "
                             f"variables are part of statistics_per_var. Please add also information on the missing "
                             f"statistics for the variables: {missing}")
        # too many entries, remove unused
        target_var = helpers.to_list(self.data_store.get("target_var"))
        unused_vars = set(stat.keys()).difference(set(var).union(target_var))
        if len(unused_vars) > 0:
            logging.info(f"There are unused keys in statistics_per_var. Therefore remove keys: {unused_vars}")
            stat_new = helpers.remove_items(stat, list(unused_vars))
            self._set_param("statistics_per_var", stat_new)

    def _check_target_var(self):
        """Check if target variable is in statistics_per_var dictionary."""
        target_var = helpers.to_list(self.data_store.get("target_var"))
        stat = self.data_store.get("statistics_per_var")
        if not set(target_var).issubset(stat.keys()):
            raise ValueError(f"Could not find target variable {target_var} in statistics_per_var.")


if __name__ == "__main__":
    formatter = '%(asctime)s - %(levelname)s: %(message)s  [%(filename)s:%(funcName)s:%(lineno)s]'
    logging.basicConfig(format=formatter, level=logging.DEBUG)

    parser = argparse.ArgumentParser()
    parser.add_argument('--experiment_date', metavar='--exp_date', type=str, nargs=1, default=None,
                        help="set experiment date as string")
    parser_args = parser.parse_args()
    with RunEnvironment():
        setup = ExperimentSetup(parser_args, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'])