import logging
from typing import Any, Tuple, Dict, List

from src.data_generator import DataGenerator
from src.helpers import TimeTracking
from src.modules.run_environment import RunEnvironment
from src.datastore import NameNotFoundInDataStore, NameNotFoundInScope
from src.join import EmptyQueryResult


DEFAULT_ARGS_LIST = ["data_path", "network", "stations", "variables", "interpolate_dim", "target_dim", "target_var"]
DEFAULT_KWARGS_LIST = ["limit_nan_fill", "window_history", "window_lead_time", "statistics_per_var", "station_type"]


class PreProcessing(RunEnvironment):

    """
    Pre-process your data by using this class. It includes time tracking and uses the experiment setup to look for data
    and stores it if not already in local disk. Further, it provides this data as a generator and checks for valid
    stations (in this context: valid=data available). Finally, it splits the data into valid training, validation and
    testing subsets.
    """

    def __init__(self):
        """Initialise the run environment and immediately execute the pre-processing workflow."""

        # create run framework
        super().__init__()

        # run the pre-processing pipeline (station validity check + train/val/test split)
        self._run()

    def _create_args_dict(self, arg_list: List[str], scope: str = "general") -> Dict[str, Any]:
        """
        Look up the given argument names in the data store for the given scope.

        Names that cannot be resolved (neither in the data store nor in the requested scope) are silently skipped,
        so the returned dict may contain fewer keys than `arg_list` — downstream code decides whether a missing
        entry is an error.
        :param arg_list: names of the arguments to look up
        :param scope: data store scope to query (default "general")
        :return: mapping from argument name to stored value for every resolvable name
        """
        args = {}
        for arg in arg_list:
            try:
                args[arg] = self.data_store.get(arg, scope)
            except (NameNotFoundInDataStore, NameNotFoundInScope):
                # entries are optional at this point; absence is handled by the caller
                pass
        return args

    def _run(self):
        """Validate the full station list, store the result, and create the train/val/test subsets."""
        args = self._create_args_dict(DEFAULT_ARGS_LIST)
        kwargs = self._create_args_dict(DEFAULT_KWARGS_LIST)
        valid_stations = self.check_valid_stations(args, kwargs, self.data_store.get("stations", "general"))
        self.data_store.put("stations", valid_stations, "general")
        self.split_train_val_test()

    def split_train_val_test(self):
        """Split all stations into train/val/test subsets and create one generator per subset scope."""
        fraction_of_training = self.data_store.get("fraction_of_training", "general")
        stations = self.data_store.get("stations", "general")
        train_index, val_index, test_index = self.split_set_indices(len(stations), fraction_of_training)
        for (ind, scope) in zip([train_index, val_index, test_index], ["train", "val", "test"]):
            self.create_set_split(ind, scope)

    @staticmethod
    def split_set_indices(total_length: int, fraction: float) -> Tuple[slice, slice, slice]:
        """
        create the training, validation and test subset slice indices for given total_length. The test data consists of
        (1-fraction) of total_length (fraction*len:end). Train and validation data therefore are made from fraction of
        total_length (0:fraction*len). Train and validation data is split by the factor 0.8 for train and 0.2 for
        validation.
        :param total_length: list with all objects to split
        :param fraction: ratio between test and union of train/val data
        :return: slices for each subset in the order: train, val, test
        """
        pos_test_split = int(total_length * fraction)
        # hoist the train/val boundary so the truncation is computed only once and both slices agree by construction
        pos_train_val_split = int(pos_test_split * 0.8)
        train_index = slice(0, pos_train_val_split)
        val_index = slice(pos_train_val_split, pos_test_split)
        test_index = slice(pos_test_split, total_length)
        return train_index, val_index, test_index

    def create_set_split(self, index_list: slice, set_name: str):
        """
        Select the stations for one subset, validate them, and register a DataGenerator for that subset's scope.
        :param index_list: slice selecting this subset's stations from the full station list
        :param set_name: subset name, one of "train", "val" or "test" (appended to the "general" scope)
        """
        scope = f"general.{set_name}"
        args = self._create_args_dict(DEFAULT_ARGS_LIST, scope)
        kwargs = self._create_args_dict(DEFAULT_KWARGS_LIST, scope)
        stations = args["stations"]
        if self.data_store.get("use_all_stations_on_all_data_sets", scope):
            # every subset works on the complete station list instead of its own share
            set_stations = stations
        else:
            set_stations = stations[index_list]
        logging.debug(f"{set_name.capitalize()} stations (len={len(set_stations)}): {set_stations}")
        set_stations = self.check_valid_stations(args, kwargs, set_stations)
        self.data_store.put("stations", set_stations, scope)
        # re-read the args so the generator is built with the just stored, validated station list
        set_args = self._create_args_dict(DEFAULT_ARGS_LIST, scope)
        data_set = DataGenerator(**set_args, **kwargs)
        self.data_store.put("generator", data_set, scope)

    @staticmethod
    def check_valid_stations(args: Dict, kwargs: Dict, all_stations: List[str]) -> List[str]:
        """
        Check if all given stations in `all_stations` are valid. Valid means, that there is data available for the given
        time range (is included in `kwargs`). The shape and the loading time are logged in debug mode.
        :param args: Dictionary with required parameters for DataGenerator class (`data_path`, `network`, `stations`,
            `variables`, `interpolate_dim`, `target_dim`, `target_var`).
        :param kwargs: keyword parameters for the DataGenerator class (e.g. `start`, `interpolate_method`,
            `window_lead_time`).
        :param all_stations: All stations to check.
        :return: Corrected list containing only valid station IDs.
        """
        t_outer = TimeTracking()
        t_inner = TimeTracking(start=False)
        logging.info("check valid stations started")
        valid_stations = []

        # all required arguments of the DataGenerator can be found in args, all optional ones in kwargs
        data_gen = DataGenerator(**args, **kwargs)
        for station in all_stations:
            t_inner.run()
            try:
                # accessing the generator loads the station's data; failure marks the station as invalid
                (history, label) = data_gen[station]
                valid_stations.append(station)
                logging.debug(f"{station}: history_shape = {history.shape}")
                logging.debug(f"{station}: loading time = {t_inner}")
            except (AttributeError, EmptyQueryResult):
                continue
        logging.info(f"run for {t_outer} to check {len(all_stations)} station(s). Found {len(valid_stations)}/"
                     f"{len(all_stations)} valid stations.")
        return valid_stations