__author__ = 'Felix Kleinert, Lukas Leufen'
__date__ = '2020-02-07'


from src.run_modules.run_environment import RunEnvironment
from src.data_handling.data_generator import DataGenerator
import numpy as np
import logging
import xarray as xr
import os
import re


class BootStraps(RunEnvironment):
    """
    Create shuffled ("bootstrap") versions of the test data and store them as netCDF files.

    On construction, the test data generator and the bootstrap path are read from the data store. Afterwards, a
    shuffled data file is created for each station of the generator, unless a suitable file already exists in the
    bootstrap path.
    """

    def __init__(self):

        super().__init__()
        self.test_data: DataGenerator = self.data_store.get("generator", "general.test")
        self.number_bootstraps = 200
        self.bootstrap_path = self.data_store.get("bootstrap_path", "general")
        self.create_shuffled_data()

    def create_shuffled_data(self):
        """
        Create shuffled data. Use original test data, add dimension 'boots' with length number of bootstraps and insert
        randomly selected variables. If there is a suitable local file for requested window size and number of
        bootstraps, no additional file will be created inside this function.
        """
        variables_str = '_'.join(sorted(self.test_data.variables))
        window = self.test_data.window_history_size
        for station in self.test_data.stations:
            valid, nboot = self.valid_bootstrap_file(station, variables_str, window)
            if valid:
                # a suitable file is already stored locally, nothing to do for this station
                continue
            logging.info(f'create bootstap data for {station}')
            hist, _ = self.test_data[station]
            file_name = f"{station}_{variables_str}_hist{window}_nboots{nboot}_shuffled.nc"
            file_path = os.path.join(self.bootstrap_path, file_name)
            # append the new dimension 'boots' as last axis
            data = hist.copy().expand_dims({'boots': range(nboot)}, axis=-1)
            shuffled_variable = np.full(data.shape, np.nan)
            for i, var in enumerate(data.coords['variables']):
                single_variable = data.sel(variables=var).values
                # NOTE(review): the indexing below assumes that 'variables' is the last dimension of the history
                # data (and therefore the second to last one after adding 'boots') - confirm against DataGenerator
                shuffled_variable[..., i, :] = self.shuffle_single_variable(single_variable)
            shuffled_data = xr.DataArray(shuffled_variable, coords=data.coords, dims=data.dims)
            shuffled_data.to_netcdf(file_path)

    def valid_bootstrap_file(self, station, variables, window):
        """
        Compare local bootstrap files with given settings for station, variables, window and number of bootstraps. If
        a match was found, this method returns a tuple (True, None). In any other case, it returns (False, max_nboot),
        where max_nboot is the maximum of the number of bootstraps of this class and the highest boot number found in
        the local storage for this station and these variables. A match is defined so that the window length of the
        file is ge than the given window size and the number of boots of the file is ge than the number of boots from
        this class. Furthermore, this function deletes local files, if they match the station and variables pattern
        but don't fit the window and bootstrap condition. The caller (create_shuffled_data) creates a replacement file
        with the requested window size and max_nboot bootstraps.

        NOTE(review): a deleted file may have had a longer window than the replacement - confirm this is intended.

        :param station: name of the station, used as first part of the file name
        :param variables: all variable names joined by underscore, used as second part of the file name
        :param window: requested window history size, compared with the number after 'hist' in the file name
        :return: (True, None) if a suitable file exists, else (False, max_nboot)
        """
        # Station and variables are literal parts of the file name and must not be interpreted as regex syntax.
        # The pattern is applied as prefix match (re.match), so any file name ending (e.g. '.nc') is accepted.
        prefix = f"{re.escape(str(station))}_{re.escape(str(variables))}"
        regex = re.compile(prefix + r"_hist(\d+)_nboots(\d+)_shuffled")
        max_nboot = self.number_bootstraps
        for file in os.listdir(self.bootstrap_path):
            match = regex.match(file)
            if match is None:
                continue
            window_file = int(match.group(1))
            nboot_file = int(match.group(2))
            max_nboot = max(max_nboot, nboot_file)
            if (window_file >= window) and (nboot_file >= self.number_bootstraps):
                return True, None
            # file belongs to this station and these variables, but is too small: remove it
            os.remove(os.path.join(self.bootstrap_path, file))
        return False, max_nboot

    @staticmethod
    def shuffle_single_variable(data: np.ndarray) -> np.ndarray:
        """
        Draw randomly (with replacement, default of np.random.choice) from all values of data.

        :param data: array with the values to draw from
        :return: array of same shape as data, filled with randomly selected values of data
        """
        return np.random.choice(data.reshape(-1), size=data.shape)


if __name__ == "__main__":
    # Manual run: set up an experiment for three stations, run the pre-processing and create the bootstrap files.

    from src.run_modules.experiment_setup import ExperimentSetup
    from src.run_modules.run_environment import RunEnvironment
    from src.run_modules.pre_processing import PreProcessing

    # log format: time stamp, level and message, followed by the origin of the log call
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s: %(message)s  [%(filename)s:%(funcName)s:%(lineno)s]',
        level=logging.INFO,
    )

    experiment_kwargs = {
        "stations": ['DEBW107', 'DEBY081', 'DEBW013'],
        "station_type": 'background',
        "trainable": True,
        "window_history_size": 9,
    }

    with RunEnvironment():
        # the stages are executed in this order inside the same run environment
        ExperimentSetup(**experiment_kwargs)
        PreProcessing()

        BootStraps()