Skip to content
Snippets Groups Projects
Select Git revision
  • master
  • singularity
  • conda
  • singularity-no-venv
4 results

setup.sh

Blame
  • Forked from Stefan Kesselheim / sc_venv_template
    Source project has a limited visibility.
    pre_processing.py 13.33 KiB
    """Pre-processing run module: station validity checks, subset splits and latex/markdown data reports."""
    __author__ = "Lukas Leufen, Felix Kleinert"
    __date__ = "2019-11-25"
    
    
    import logging
    import os
    from typing import Tuple, Dict, List
    
    import numpy as np
    import pandas as pd
    
    from src.data_handling.data_generator import DataGenerator
    from src.helpers import TimeTracking, check_path_and_create
    from src.join import EmptyQueryResult
    from src.run_modules.run_environment import RunEnvironment
    
    # Names of the required DataGenerator arguments; looked up in the data store via create_args_dict.
    DEFAULT_ARGS_LIST = [
        "data_path",
        "network",
        "stations",
        "variables",
        "interpolate_dim",
        "target_dim",
        "target_var",
    ]
    # Names of the optional DataGenerator keyword arguments; also looked up in the data store via create_args_dict.
    DEFAULT_KWARGS_LIST = [
        "limit_nan_fill",
        "window_history_size",
        "window_lead_time",
        "statistics_per_var",
        "min_length",
        "station_type",
        "overwrite_local_data",
        "start",
        "end",
        "sampling",
        "transformation",
        "extreme_values",
        "extremes_on_right_tail_only",
    ]
    
    
    class PreProcessing(RunEnvironment):

        """
        Pre-process your data by using this class. It includes time tracking and uses the experiment setup to look for data
        and stores it locally if it is not already on local disk. Further, it provides this data as a generator and checks
        for valid stations (in this context: valid=data available). Finally, it splits the data into valid training,
        validation and testing subsets.
        """

        def __init__(self):

            # create run framework
            super().__init__()

            # the whole pre-processing pipeline is executed directly on construction
            self._run()

        def _run(self):
            """
            Run all pre-processing steps: keep only stations with valid data (stored as "stations" in the data store), split
            them into the train/val/test/train_val subsets and finally report on the result.
            """
            args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope="preprocessing")
            kwargs = self.data_store.create_args_dict(DEFAULT_KWARGS_LIST, scope="preprocessing")
            stations = self.data_store.get("stations")
            valid_stations = self.check_valid_stations(args, kwargs, stations, load_tmp=False, save_tmp=False, name="all")
            self.data_store.set("stations", valid_stations)
            self.split_train_val_test()
            self.report_pre_processing()

        def report_pre_processing(self):
            """
            Log (debug level) the number of stations per subset and the shapes of the first element of the test generator,
            then create the latex/markdown report via create_latex_report.
            """
            logging.debug(20 * '##')
            n_train = len(self.data_store.get('generator', 'train'))
            n_val = len(self.data_store.get('generator', 'val'))
            test_generator = self.data_store.get('generator', 'test')
            n_test = len(test_generator)
            n_total = n_train + n_val + n_test
            logging.debug(f"Number of all stations: {n_total}")
            logging.debug(f"Number of training stations: {n_train}")
            logging.debug(f"Number of val stations: {n_val}")
            logging.debug(f"Number of test stations: {n_test}")
            # fetch the first test element only once instead of once per logged shape
            first_test_element = test_generator[0]
            logging.debug(f"TEST SHAPE OF GENERATOR CALL: {first_test_element[0].shape} "
                          f"{first_test_element[1].shape}")
            self.create_latex_report()

        def create_latex_report(self):
            """
            This function creates tables with information on the station meta data and a summary on subset sample sizes.

            * station_sample_size.md: see table below
            * station_sample_size.tex: same as table below, but as latex table
            * station_sample_size_short.tex: reduced size table without any meta data besides station ID, as latex table
            * station_describe_short.tex: one row per subset with the total number of samples ("sum") and the output of
              `DataFrame.describe` (count, mean, std, min, 5%...95% percentiles, max) over the per-station sample sizes,
              as latex table

            All tables are stored inside experiment_path inside the folder latex_report. The table format (e.g. which meta
            data is highlighted) is currently hardcoded to have a stable table style. If further styles are needed, it is
            better to add an additional style than modifying the existing table styles.

            | stat. ID   | station_name                              |   station_lon |   station_lat |   station_alt |   train |   val |   test |
            |------------|-------------------------------------------|---------------|---------------|---------------|---------|-------|--------|
            | DEBW013    | Stuttgart Bad Cannstatt                   |        9.2297 |       48.8088 |           235 |    1434 |   712 |   1080 |
            | DEBW076    | Baden-Baden                               |        8.2202 |       48.7731 |           148 |    3037 |   722 |    710 |
            | DEBW087    | Schwäbische_Alb                           |        9.2076 |       48.3458 |           798 |    3044 |   714 |   1087 |
            | DEBW107    | Tübingen                                  |        9.0512 |       48.5077 |           325 |    1803 |   715 |   1087 |
            | DEBY081    | Garmisch-Partenkirchen/Kreuzeckbahnstraße |       11.0631 |       47.4764 |           735 |    2935 |   525 |    714 |
            | # Stations | nan                                       |      nan      |      nan      |           nan |       5 |     5 |      5 |
            | # Samples  | nan                                       |      nan      |      nan      |           nan |   12253 |  3388 |   4678 |

            """
            meta_data = ['station_name', 'station_lon', 'station_lat', 'station_alt']
            meta_round = ["station_lon", "station_lat", "station_alt"]
            precision = 4
            # summary rows appended below the station rows (in this order)
            summary_rows = ["# Stations", "# Samples"]
            path = os.path.join(self.data_store.get("experiment_path"), "latex_report")
            check_path_and_create(path)
            set_names = ["train", "val", "test"]
            df = pd.DataFrame(columns=meta_data + set_names)
            for set_name in set_names:
                data: DataGenerator = self.data_store.get("generator", set_name)
                for station in data.stations:
                    # load the station only once; it provides both the label length and (if missing) the meta data
                    station_data = data.get_data_generator(station)
                    df.loc[station, set_name] = station_data.get_transposed_label().shape[0]
                    if df.loc[station, meta_data].isnull().any():
                        df.loc[station, meta_data] = station_data.meta.loc[meta_data].values.flatten()
                # aggregate over station rows only, otherwise the freshly added "# Samples" row would be counted as an
                # additional station
                station_values = df.loc[~df.index.isin(summary_rows), set_name]
                df.loc["# Samples", set_name] = station_values.sum()
                df.loc["# Stations", set_name] = station_values.count()
            df[meta_round] = df[meta_round].astype(float).round(precision)
            df.sort_index(inplace=True)
            # sort stations alphabetically but keep the summary rows at the bottom
            df = df.reindex(df.index.drop(summary_rows).to_list() + summary_rows)
            df.index.name = 'stat. ID'
            column_format = self.create_column_format_for_tex(df)
            df.to_latex(os.path.join(path, "station_sample_size.tex"), na_rep='---', column_format=column_format)
            with open(os.path.join(path, "station_sample_size.md"), mode="w", encoding='utf-8') as md_file:
                df.to_markdown(md_file, tablefmt="github")
            df_nometa = df.drop(meta_data, axis=1)
            # the short table has fewer columns, so its column format must be derived from df_nometa itself
            column_format = self.create_column_format_for_tex(df_nometa)
            df_nometa.to_latex(os.path.join(path, "station_sample_size_short.tex"), na_rep='---',
                               column_format=column_format)
            # describe the per-station sample sizes only (summary rows excluded)
            # NOTE(review): astype('int32') raises if describe yields NaN (e.g. std of a single station) — confirm
            df_descr = df_nometa.drop(summary_rows).astype('float32').describe(
                percentiles=[.05, .1, .25, .5, .75, .9, .95]).astype('int32')
            df_descr = pd.concat([df_nometa.loc[['# Samples']], df_descr]).T
            df_descr.rename(columns={"# Samples": "sum"}, inplace=True)
            column_format = self.create_column_format_for_tex(df_descr)
            df_descr.to_latex(os.path.join(path, "station_describe_short.tex"), na_rep='---',
                              column_format=column_format)

        @staticmethod
        def create_column_format_for_tex(df: pd.DataFrame) -> str:
            """
            Create the column format for a latex table based on the shape of a given DataFrame.

            One format character is created per data column plus one for the index. All columns are centred ('c'),
            except for the first one (index, 'l') and the last one ('r').

            :param df: DataFrame the latex table is created from
            :return: column format string, e.g. 'lccr' for a DataFrame with three columns
            """
            column_format = np.repeat('c', df.shape[1] + 1)
            column_format[0] = 'l'
            column_format[-1] = 'r'
            column_format = ''.join(column_format.tolist())
            return column_format

        def split_train_val_test(self) -> None:
            """
            Split all subsets. Currently: train, val, test and train_val (actually this is only the merge of train and val,
            but as a separate generator). IMPORTANT: Do not change the order of the execution of create_set_split. The
            train subset always needs to be executed first, to set a proper transformation.
            """
            fraction_of_training = self.data_store.get("fraction_of_training")
            stations = self.data_store.get("stations")
            train_index, val_index, test_index, train_val_index = self.split_set_indices(len(stations), fraction_of_training)
            subset_names = ["train", "val", "test", "train_val"]
            if subset_names[0] != "train":  # pragma: no cover
                raise AssertionError(f"Make sure, that the train subset is always at first execution position! Given subset "
                                     f"order was: {subset_names}.")
            for (ind, scope) in zip([train_index, val_index, test_index, train_val_index], subset_names):
                self.create_set_split(ind, scope)

        @staticmethod
        def split_set_indices(total_length: int, fraction: float) -> Tuple[slice, slice, slice, slice]:
            """
            Create the training, validation and test subset slice indices for given total_length. The test data consists of
            (1-fraction) of total_length (fraction*len:end). Train and validation data therefore are made from fraction of
            total_length (0:fraction*len). Train and validation data is split by the factor 0.8 for train and 0.2 for
            validation. In addition, split_set_indices returns also the combination of training and validation subset.

            :param total_length: number of objects to split (e.g. number of stations)
            :param fraction: ratio between the union of train/val data and the total length; the rest is test data
            :return: slices for each subset in the order: train, val, test, train_val
            """
            pos_test_split = int(total_length * fraction)
            pos_val_split = int(pos_test_split * 0.8)
            train_index = slice(0, pos_val_split)
            val_index = slice(pos_val_split, pos_test_split)
            test_index = slice(pos_test_split, total_length)
            train_val_index = slice(0, pos_test_split)
            return train_index, val_index, test_index, train_val_index

        def create_set_split(self, index_list: slice, set_name: str) -> None:
            """
            Create the subset for given split index and store the DataGenerator with given set name in data store as
            `generator`. Checks for all valid stations using the default (kw)args for given scope and creates the
            DataGenerator for all valid stations. Also sets all transformation information, if subset is training set. Make
            sure, that the train set is executed first, and all other subsets afterwards.

            :param index_list: slice selecting the stations to use for the set. If attribute
                use_all_stations_on_all_data_sets=True, this slice is ignored and all stations are used.
            :param set_name: name (scope) to load/save all information from/to data store.
            """
            args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope=set_name)
            kwargs = self.data_store.create_args_dict(DEFAULT_KWARGS_LIST, scope=set_name)
            stations = args["stations"]
            if self.data_store.get("use_all_stations_on_all_data_sets", scope=set_name):
                set_stations = stations
            else:
                set_stations = stations[index_list]
            logging.debug(f"{set_name.capitalize()} stations (len={len(set_stations)}): {set_stations}")
            set_stations = self.check_valid_stations(args, kwargs, set_stations, load_tmp=False, name=set_name)
            self.data_store.set("stations", set_stations, scope=set_name)
            # re-create the args so that the generator is built from the validated station list
            set_args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope=set_name)
            data_set = DataGenerator(**set_args, **kwargs)
            self.data_store.set("generator", data_set, scope=set_name)
            if set_name == "train":
                # all other subsets reuse the transformation fitted on the training data
                self.data_store.set("transformation", data_set.transformation)

        @staticmethod
        def check_valid_stations(args: Dict, kwargs: Dict, all_stations: List[str], load_tmp=True, save_tmp=True, name=None):
            """
            Check if all given stations in `all_stations` are valid. Valid means, that there is data available for the given
            time range (is included in `kwargs`). The shape and the loading time are logged in debug mode.

            :param args: Dictionary with required parameters for DataGenerator class (`data_path`, `network`, `stations`,
                `variables`, `interpolate_dim`, `target_dim`, `target_var`).
            :param kwargs: keyword parameters for the DataGenerator class (e.g. `start`, `interpolate_method`,
                `window_lead_time`).
            :param all_stations: All stations to check.
            :param load_tmp: passed as `load_local_tmp_storage` to `DataGenerator.get_data_generator`.
            :param save_tmp: passed as `save_local_tmp_storage` to `DataGenerator.get_data_generator`.
            :param name: name to display in the logging info message
            :return: Corrected list containing only valid station IDs.
            """
            t_outer = TimeTracking()
            t_inner = TimeTracking(start=False)
            logging.info(f"check valid stations started{' (%s)' % name if name else ''}")
            valid_stations = []

            # all required arguments of the DataGenerator can be found in args, optional keyword arguments in kwargs
            data_gen = DataGenerator(**args, **kwargs)
            for pos, station in enumerate(all_stations):
                t_inner.run()
                logging.info(f"check station {station} ({pos + 1} / {len(all_stations)})")
                try:
                    data = data_gen.get_data_generator(key=station, load_local_tmp_storage=load_tmp,
                                                       save_local_tmp_storage=save_tmp)
                    if data.history is None:
                        # no history available: treat like a missing attribute and skip this station
                        raise AttributeError
                    valid_stations.append(station)
                    logging.debug(f'{station}: history_shape = {data.history.transpose("datetime", "window", "Stations", "variables").shape}')
                    logging.debug(f"{station}: loading time = {t_inner}")
                except (AttributeError, EmptyQueryResult):
                    continue
            logging.info(f"run for {t_outer} to check {len(all_stations)} station(s). Found {len(valid_stations)}/"
                         f"{len(all_stations)} valid stations.")
            return valid_stations