Skip to content
Snippets Groups Projects
Select Git revision
  • 414-include-crps-analysis-and-other-ens-verif-methods-or-plots
  • master default protected
  • enxhi_issue460_remove_TOAR-I_access
  • michael_issue459_preprocess_german_stations
  • sh_pollutants
  • develop protected
  • release_v2.4.0
  • michael_issue450_feat_load-ifs-data
  • lukas_issue457_feat_set-config-paths-as-parameter
  • lukas_issue454_feat_use-toar-statistics-api-v2
  • lukas_issue453_refac_advanced-retry-strategy
  • lukas_issue452_bug_update-proj-version
  • lukas_issue449_refac_load-era5-data-from-toar-db
  • lukas_issue451_feat_robust-apriori-estimate-for-short-timeseries
  • lukas_issue448_feat_load-model-from-path
  • lukas_issue447_feat_store-and-load-local-clim-apriori-data
  • lukas_issue445_feat_data-insight-plot-monthly-distribution
  • lukas_issue442_feat_bias-free-evaluation
  • lukas_issue444_feat_choose-interp-method-cams
  • lukas_issue384_feat_aqw-data-handler
  • v2.4.0 protected
  • v2.3.0 protected
  • v2.2.0 protected
  • v2.1.0 protected
  • Kleinert_etal_2022_initial_submission
  • v2.0.0 protected
  • v1.5.0 protected
  • v1.4.0 protected
  • v1.3.0 protected
  • v1.2.1 protected
  • v1.2.0 protected
  • v1.1.0 protected
  • IntelliO3-ts-v1.0_R1-submit
  • v1.0.0 protected
  • v0.12.2 protected
  • v0.12.1 protected
  • v0.12.0 protected
  • v0.11.0 protected
  • v0.10.0 protected
  • IntelliO3-ts-v1.0_initial-submit
40 results

update_badge.sh

Blame
  • bootstraps.py 7.68 KiB
    __author__ = 'Felix Kleinert, Lukas Leufen'
    __date__ = '2020-02-07'
    
    
    from src.run_modules.run_environment import RunEnvironment
    from src.data_handling.data_generator import DataGenerator
    import numpy as np
    import logging
    import dask.array as da
    import xarray as xr
    import os
    import re
    from src import helpers
    
    
    class BootStrapGenerator:
        """
        Wrap a DataGenerator and yield, for every station and every variable, `boots`
        bootstrapped (history, label) pairs in which that single variable's history is
        replaced by pre-shuffled data loaded from `bootstrap_path`.
        """

        def __init__(self, orig_generator, boots: int, chunksize, bootstrap_path: str):
            # source generator providing the original (history, label) pairs
            self.orig_generator: DataGenerator = orig_generator
            self.stations = self.orig_generator.stations
            self.variables = self.orig_generator.variables
            # number of bootstrap realisations yielded per (station, variable)
            self.boots = boots
            # NOTE(review): chunksize is stored but never read inside this class — confirm intent
            self.chunksize = chunksize
            # directory holding the pre-created *_shuffled.nc bootstrap files
            self.bootstrap_path = bootstrap_path
            self._iterator = 0
            # records the substituted variable name once per label entry, in yield order
            self.bootstrap_meta = []

        def __len__(self):
            """
            Return the total number of yields of the generator: one per combination of
            station, bootstrap realisation and variable.
            """
            return len(self.orig_generator)*self.boots*len(self.variables)

        def get_generator(self):
            """
            Build a generator over all (station, variable, boot) combinations. For each
            combination the history block is yielded with the current variable replaced
            by its bootstrapped counterpart from the local shuffled-data file. As a side
            effect, `bootstrap_meta` is extended with the substituted variable name once
            per label entry, so it stays aligned with the concatenated labels.
            :return: generator yielding tuples of (bootstrapped history, label)
            """
            while True:
                for i, data in enumerate(self.orig_generator):
                    station = self.orig_generator.get_station_key(i)
                    logging.info(f"station: {station}")
                    hist, label = data
                    len_of_label = len(label)
                    # lazily-opened shuffled bootstrap data for this station
                    shuffled_data = self.load_boot_data(station)
                    for var in self.variables:
                        logging.info(f"  var: {var}")
                        for boot in range(self.boots):
                            logging.debug(f"boot: {boot}")
                            # history with the current variable removed ...
                            boot_hist = hist.sel(variables=helpers.list_pop(self.variables, var))
                            shuffled_var = shuffled_data.sel(variables=var, boots=boot).expand_dims("variables").drop("boots").transpose("datetime", "window", "Stations", "variables")
                            # ... recombined with the bootstrapped realisation of that variable
                            boot_hist = boot_hist.combine_first(shuffled_var)
                            boot_hist = boot_hist.sortby("variables")
                            self.bootstrap_meta.extend([var]*len_of_label)
                            yield boot_hist, label
                # single full pass only; the enclosing while loop is never re-entered
                return

        def load_boot_data(self, station):
            """
            Open the shuffled bootstrap file for `station` as a chunked (lazy) DataArray.
            Assumes exactly one file in `bootstrap_path` matches the station pattern;
            raises IndexError if none does.
            """
            files = os.listdir(self.bootstrap_path)
            regex = re.compile(rf"{station}_\w*\.nc")
            file_name = os.path.join(self.bootstrap_path, list(filter(regex.search, files))[0])
            shuffled_data = xr.open_dataarray(file_name, chunks=100)
            return shuffled_data
    
    
    class BootStraps(RunEnvironment):
        """
        Orchestrate bootstrap evaluation data: make sure shuffled bootstrap files exist
        locally (creating any that are missing or too small) and expose a
        BootStrapGenerator over them.
        """

        def __init__(self, data, bootstrap_path: str, number_bootstraps: int = 10):

            super().__init__()
            # original test data generator
            self.data: DataGenerator = data
            self.number_bootstraps = number_bootstraps
            # directory where shuffled *_shuffled.nc files are stored / created
            self.bootstrap_path = bootstrap_path
            self.chunks = self.get_chunk_size()
            # no-op for stations whose shuffled file already covers window and boot count
            self.create_shuffled_data()
            self._boot_strap_generator = BootStrapGenerator(self.data, self.number_bootstraps, self.chunks, self.bootstrap_path)

        def get_boot_strap_meta(self):
            """Return the variable name recorded for each label entry yielded so far."""
            return self._boot_strap_generator.bootstrap_meta

        def boot_strap_generator(self):
            """Return the generator of bootstrapped (history, label) pairs."""
            return self._boot_strap_generator.get_generator()

        def get_boot_strap_generator_length(self):
            """Return the total number of yields (stations * variables * bootstraps)."""
            return self._boot_strap_generator.__len__()

        def get_chunk_size(self):
            """
            Derive a chunk shape from the first sample's history: 100 along the leading
            (datetime) axis, the full extent of the remaining history axes, and
            `number_bootstraps` along the trailing boots axis.
            """
            hist, _ = self.data[0]
            return (100, *hist.shape[1:], self.number_bootstraps)

        def create_shuffled_data(self):
            """
            Create shuffled data. Use original test data, add dimension 'boots' with length number of bootstraps and insert
            randomly selected variables. If there is a suitable local file for requested window size and number of
            bootstraps, no additional file will be created inside this function.
            """
            logging.info("create shuffled bootstrap data")
            variables_str = '_'.join(sorted(self.data.variables))
            window = self.data.window_history_size
            for station in self.data.stations:
                valid, nboot = self.valid_bootstrap_file(station, variables_str, window)
                if not valid:
                    # NOTE(review): "bootstap" typo kept — runtime log message, not a comment
                    logging.info(f'create bootstap data for {station}')
                    hist, _ = self.data[station]
                    data = hist.copy()
                    file_name = f"{station}_{variables_str}_hist{window}_nboots{nboot}_shuffled.nc"
                    file_path = os.path.join(self.bootstrap_path, file_name)
                    # append a trailing 'boots' dimension holding the bootstrap realisations
                    data = data.expand_dims({'boots': range(nboot)}, axis=-1)
                    shuffled_variable = []
                    for i, var in enumerate(data.coords['variables']):
                        single_variable = data.sel(variables=var).values
                        shuffled_variable.append(self.shuffle_single_variable(single_variable, chunks=(100, *data.shape[1:3], data.shape[-1])))
                    # re-stack the per-variable arrays on the (second to last) variables axis
                    shuffled_variable_da = da.stack(shuffled_variable, axis=-2, ).rechunk("auto")
                    shuffled_data = xr.DataArray(shuffled_variable_da, coords=data.coords, dims=data.dims)
                    shuffled_data.to_netcdf(file_path)

        def valid_bootstrap_file(self, station, variables, window):
            """
            Compare local bootstrap files with the given settings for station, variables, window and number of
            bootstraps. If a match is found, this method returns the tuple (True, None). In any other case, it returns
            (False, max_nboot), where max_nboot is the highest boot number found in the local storage. A match is
            defined such that the file's window length is >= the given window size and its number of boots is >= the
            number of boots of this class. Furthermore, this function deletes local files that match the station
            pattern but do not satisfy the window and bootstrap conditions. This is done because it is assumed that the
            corresponding file will be re-created with a longer (or at least equal) window size and number of
            bootstraps.
            :param station: station id used in the file-name pattern
            :param variables: underscore-joined, sorted variable names used in the file-name pattern
            :param window: requested history window size
            :return: tuple of (match found, None or highest boot count seen)
            """
            regex = re.compile(rf"{station}_{variables}_hist(\d+)_nboots(\d+)_shuffled")
            max_nboot = self.number_bootstraps
            for file in os.listdir(self.bootstrap_path):
                match = regex.match(file)
                if match:
                    window_file = int(match.group(1))
                    nboot_file = int(match.group(2))
                    max_nboot = max([max_nboot, nboot_file])
                    if (window_file >= window) and (nboot_file >= self.number_bootstraps):
                        return True, None
                    else:
                        # file too small for requested settings — remove so it gets re-created
                        os.remove(os.path.join(self.bootstrap_path, file))
            return False, max_nboot

        @staticmethod
        def shuffle_single_variable(data: da.array, chunks) -> np.ndarray:
            """
            Draw a random sample (with replacement, i.e. a bootstrap resample) from the
            flattened input, shaped back to the input's shape.
            NOTE(review): da.random.choice returns a lazy dask array, not an eager
            np.ndarray as the annotation suggests — confirm before tightening types.
            """
            size = data.shape
            return da.random.choice(data.reshape(-1,), size=size, chunks=chunks)
    
    
    if __name__ == "__main__":

        from src.run_modules.experiment_setup import ExperimentSetup
        from src.run_modules.run_environment import RunEnvironment
        from src.run_modules.pre_processing import PreProcessing

        # configure root logging with the project-wide message layout
        log_format = '%(asctime)s - %(levelname)s: %(message)s  [%(filename)s:%(funcName)s:%(lineno)s]'
        logging.basicConfig(format=log_format, level=logging.INFO)

        with RunEnvironment() as env:
            # minimal experiment setup: three background stations, 9-step history window
            ExperimentSetup(stations=['DEBW107', 'DEBY081', 'DEBW013'],
                            station_type='background', trainable=True, window_history_size=9)
            PreProcessing()

            generator = env.data_store.get("generator", "general.test")
            storage_path = env.data_store.get("bootstrap_path", "general")
            n_boots = 10

            # drain the generator once so bootstrap_meta gets fully populated
            bootstraps = BootStraps(generator, storage_path, n_boots)
            for history, labels in bootstraps.boot_strap_generator():
                pass
            logging.info(f"len is {len(bootstraps.get_boot_strap_meta())}")