__author__ = "Lukas Leufen, Felix Kleinert"
__date__ = '2019-11-15'


import argparse
import logging
import os
from typing import Union, Dict, Any

from src import helpers
from src.run_modules.run_environment import RunEnvironment

DEFAULT_STATIONS = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBY052', 'DEBY032', 'DEBW022', 'DEBY004',
                    'DEBY020', 'DEBW030', 'DEBW037', 'DEBW031', 'DEBW015', 'DEBW073', 'DEBY039', 'DEBW038', 'DEBW081',
                    'DEBY075', 'DEBW040', 'DEBY053', 'DEBW059', 'DEBW027', 'DEBY072', 'DEBW042', 'DEBW039', 'DEBY001',
                    'DEBY113', 'DEBY089', 'DEBW024', 'DEBW004', 'DEBY037', 'DEBW056', 'DEBW029', 'DEBY068', 'DEBW010',
                    'DEBW026', 'DEBY002', 'DEBY079', 'DEBW084', 'DEBY049', 'DEBY031', 'DEBW019', 'DEBW001', 'DEBY063',
                    'DEBY005', 'DEBW046', 'DEBW103', 'DEBW052', 'DEBW034', 'DEBY088', ]
DEFAULT_VAR_ALL_DICT = {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum', 'u': 'average_values',
                        'v': 'average_values', 'no': 'dma8eu', 'no2': 'dma8eu', 'cloudcover': 'average_values',
                        'pblheight': 'maximum'}
DEFAULT_TRANSFORMATION = {"scope": "data", "method": "standardise", "mean": "estimate"}
DEFAULT_PLOT_LIST = ["PlotMonthlySummary", "PlotStationMap", "PlotClimatologicalSkillScore", "PlotTimeSeries",
                     "PlotCompetitiveSkillScore", "PlotBootstrapSkillScore", "plot_conditional_quantiles"]


class ExperimentSetup(RunEnvironment):
    """
    Collect all given experiment parameters, fill in defaults where arguments are missing, and store everything in
    the data store under the appropriate scope.

    :param trainable: train a new model if True, otherwise try to load an existing model
    :param create_new_model: force the creation of a new model instead of loading a stored one; implies
        trainable=True
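
    A minimal usage sketch (values are illustrative; omitted arguments fall back to the defaults set in
    __init__)::

        with RunEnvironment():
            ExperimentSetup(stations=["DEBW107", "DEBY081"], window_history_size=6)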
    """

    def __init__(self, parser_args=None, stations=None, network=None, station_type=None, variables=None,
                 statistics_per_var=None, start=None, end=None, window_history_size=None, target_var=None, target_dim=None,
                 window_lead_time=None, dimensions=None, interpolate_dim=None, interpolate_method=None,
                 limit_nan_fill=None, train_start=None, train_end=None, val_start=None, val_end=None, test_start=None,
                 test_end=None, use_all_stations_on_all_data_sets=True, trainable=None, fraction_of_train=None,
                 experiment_path=None, plot_path=None, forecast_path=None, overwrite_local_data=None, sampling="daily",
                 create_new_model=None, bootstrap_path=None, permute_data_on_training=None, transformation=None,
                 evaluate_bootstraps=True, plot_list=None, number_of_bootstraps=None):

        # create run framework
        super().__init__()

        # experiment setup
        self._set_param("data_path", helpers.prepare_host(sampling=sampling))
        self._set_param("create_new_model", create_new_model, default=True)
        if self.data_store.get("create_new_model", "general"):
            trainable = True
        data_path = self.data_store.get("data_path", "general")
        bootstrap_path = helpers.set_bootstrap_path(bootstrap_path, data_path, sampling)
        self._set_param("bootstrap_path", bootstrap_path)
        self._set_param("trainable", trainable, default=True)
        self._set_param("fraction_of_training", fraction_of_train, default=0.8)
        self._set_param("permute_data", permute_data_on_training, default=False, scope="general.train")

        # set experiment name
        exp_date = self._get_parser_args(parser_args).get("experiment_date")
        exp_name, exp_path = helpers.set_experiment_name(experiment_date=exp_date, experiment_path=experiment_path,
                                                         sampling=sampling)
        self._set_param("experiment_name", exp_name)
        self._set_param("experiment_path", exp_path)
        helpers.check_path_and_create(self.data_store.get("experiment_path", "general"))

        # set plot path
        default_plot_path = os.path.join(exp_path, "plots")
        self._set_param("plot_path", plot_path, default=default_plot_path)
        helpers.check_path_and_create(self.data_store.get("plot_path", "general"))

        # set results path
        default_forecast_path = os.path.join(exp_path, "forecasts")
        self._set_param("forecast_path", forecast_path, default_forecast_path)
        helpers.check_path_and_create(self.data_store.get("forecast_path", "general"))

        # setup for data
        self._set_param("stations", stations, default=DEFAULT_STATIONS)
        self._set_param("network", network, default="AIRBASE")
        self._set_param("station_type", station_type, default=None)
        self._set_param("statistics_per_var", statistics_per_var, default=DEFAULT_VAR_ALL_DICT)
        self._set_param("variables", variables, default=list(self.data_store.get("statistics_per_var", "general").keys()))
        self._compare_variables_and_statistics()
        self._set_param("start", start, default="1997-01-01", scope="general")
        self._set_param("end", end, default="2017-12-31", scope="general")
        self._set_param("window_history_size", window_history_size, default=13)
        self._set_param("overwrite_local_data", overwrite_local_data, default=False, scope="general.preprocessing")
        self._set_param("sampling", sampling)
        self._set_param("transformation", transformation, default=DEFAULT_TRANSFORMATION)
        self._set_param("transformation", None, scope="general.preprocessing")

        # target
        self._set_param("target_var", target_var, default="o3")
        self._check_target_var()
        self._set_param("target_dim", target_dim, default='variables')
        self._set_param("window_lead_time", window_lead_time, default=3)

        # interpolation
        self._set_param("dimensions", dimensions, default={'new_index': ['datetime', 'Stations']})
        self._set_param("interpolate_dim", interpolate_dim, default='datetime')
        self._set_param("interpolate_method", interpolate_method, default='linear')
        self._set_param("limit_nan_fill", limit_nan_fill, default=1)

        # train set parameters
        self._set_param("start", train_start, default="1997-01-01", scope="general.train")
        self._set_param("end", train_end, default="2007-12-31", scope="general.train")

        # validation set parameters
        self._set_param("start", val_start, default="2008-01-01", scope="general.val")
        self._set_param("end", val_end, default="2009-12-31", scope="general.val")

        # test set parameters
        self._set_param("start", test_start, default="2010-01-01", scope="general.test")
        self._set_param("end", test_end, default="2017-12-31", scope="general.test")

        # train_val set parameters
        self._set_param("start", self.data_store.get("start", "general.train"), scope="general.train_val")
        self._set_param("end", self.data_store.get("end", "general.val"), scope="general.train_val")

        # use all stations on all data sets (train, val, test)
        self._set_param("use_all_stations_on_all_data_sets", use_all_stations_on_all_data_sets, default=True)

        # set post-processing instructions
        self._set_param("evaluate_bootstraps", evaluate_bootstraps, scope="general.postprocessing")
        self._set_param("number_of_bootstraps", number_of_bootstraps, default=20, scope="general.postprocessing")
        self._set_param("plot_list", plot_list, default=DEFAULT_PLOT_LIST, scope="general.postprocessing")

    def _set_param(self, param: str, value: Any, default: Any = None, scope: str = "general") -> None:
        """Store the given parameter in the data store under the given scope, falling back to default if value is
        None."""
        if value is None and default is not None:
            value = default
        self.data_store.set(param, value, scope)
        logging.debug(f"set experiment attribute: {param}({scope})={value}")

    @staticmethod
    def _get_parser_args(args: Union[Dict, argparse.Namespace]) -> Dict:
        """
        Transform the given arguments into a dictionary if they are provided as an argparse.Namespace.
        :param args: either a dictionary or an argparse.Namespace instance
        :return: a dictionary with all arguments; an empty dictionary if args is neither of the supported types
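
        An illustrative doctest (argparse.Namespace stores its attributes in __dict__):

        >>> ExperimentSetup._get_parser_args(argparse.Namespace(experiment_date="2019-11-15"))
        {'experiment_date': '2019-11-15'}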
        """
        if isinstance(args, argparse.Namespace):
            return args.__dict__
        elif isinstance(args, dict):
            return args
        else:
            return {}

    def _compare_variables_and_statistics(self):
        """Check that all requested variables are included in statistics_per_var, raise a ValueError otherwise."""
        logging.debug("check if all variables are included in statistics_per_var")
        stat = self.data_store.get("statistics_per_var", "general")
        var = self.data_store.get("variables", "general")
        if not set(var).issubset(stat.keys()):
            missing = set(var).difference(stat.keys())
            raise ValueError(f"Comparison of the given variables and statistics_per_var shows that not all requested "
                             f"variables are part of statistics_per_var. Please also add the missing statistics "
                             f"information for the variables: {missing}")

    def _check_target_var(self):
        """Check that the target variable is part of statistics_per_var and drop unused keys from that mapping."""
        target_var = helpers.to_list(self.data_store.get("target_var", "general"))
        stat = self.data_store.get("statistics_per_var", "general")
        var = self.data_store.get("variables", "general")
        if not set(target_var).issubset(stat.keys()):
            raise ValueError(f"Could not find target variable {target_var} in statistics_per_var.")
        unused_vars = set(stat.keys()).difference(set(var).union(target_var))
        if len(unused_vars) > 0:
            logging.info(f"There are unused keys in statistics_per_var, removing: {unused_vars}")
            stat_new = helpers.dict_pop(stat, list(unused_vars))
            self._set_param("statistics_per_var", stat_new)


if __name__ == "__main__":

    formatter = '%(asctime)s - %(levelname)s: %(message)s  [%(filename)s:%(funcName)s:%(lineno)s]'
    logging.basicConfig(format=formatter, level=logging.DEBUG)

    parser = argparse.ArgumentParser()
    parser.add_argument('--experiment_date', metavar='exp_date', type=str, nargs=1, default=None,
                        help="set the experiment date as a string")
    parser_args = parser.parse_args()
    with RunEnvironment():
        setup = ExperimentSetup(parser_args, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'])