From 7bea9e647a10397e0b755b11b69f45210c076fbf Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Wed, 29 Apr 2020 10:46:19 +0200
Subject: [PATCH] docs for preprocessing

---
 src/run_modules/pre_processing.py | 92 ++++++++++++++++++++++---------
 src/run_modules/training.py       |  6 +-
 2 files changed, 71 insertions(+), 27 deletions(-)

diff --git a/src/run_modules/pre_processing.py b/src/run_modules/pre_processing.py
index 3f0ce363..bc55ad7f 100644
--- a/src/run_modules/pre_processing.py
+++ b/src/run_modules/pre_processing.py
@@ -22,18 +22,38 @@ DEFAULT_KWARGS_LIST = ["limit_nan_fill", "window_history_size", "window_lead_tim

 class PreProcessing(RunEnvironment):
     """
-    Pre-process your data by using this class. It includes time tracking and uses the experiment setup to look for data
-    and stores it if not already in local disk. Further, it provides this data as a generator and checks for valid
-    stations (in this context: valid=data available). Finally, it splits the data into valid training, validation and
-    testing subsets.
+    Pre-process your data by using this class.
+
+    Schedule of pre-processing:
+        #. load and check valid stations (either download or load from disk)
+        #. split subsets (train, val, test, train & val)
+        #. create small report on data metrics
+
+    Required objects [scope] from data store:
+        * all elements from `DEFAULT_ARGS_LIST` in scope preprocessing for general data loading
+        * all elements from `DEFAULT_ARGS_LIST` in scopes [train, val, test, train_val] for custom subset settings
+        * `fraction_of_training` [.]
+        * `experiment_path` [.]
+        * `use_all_stations_on_all_data_sets` [.]
+
+    Optional objects:
+        * all elements from `DEFAULT_KWARGS_LIST` in scope preprocessing for general data loading
+        * all elements from `DEFAULT_KWARGS_LIST` in scopes [train, val, test, train_val] for custom subset settings
+
+    Sets:
+        * `stations` in [., train, val, test, train_val]
+        * `generator` in [train, val, test, train_val]
+        * `transformation` [.]
+
+    Creates:
+        * all input and output data in `data_path`
+        * latex reports in `experiment_path/latex_report`
+
     """

     def __init__(self):
-
-        # create run framework
+        """Set up and run pre-processing."""
         super().__init__()
-
-        #
         self._run()

     def _run(self):
@@ -46,6 +66,7 @@ class PreProcessing(RunEnvironment):
         self.report_pre_processing()

     def report_pre_processing(self):
+        """Log some metrics on data and create latex report."""
         logging.debug(20 * '##')
         n_train = len(self.data_store.get('generator', 'train'))
         n_val = len(self.data_store.get('generator', 'val'))
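In the class docstring above, `[scope]` distinguishes the general scope `[.]` from the subset scopes `[train, val, test, train_val]`: a lookup in a subset scope falls back to the general entry when no subset-specific value is set. The sketch below illustrates this fallback idea with a deliberately simplified stand-in; the class name and API are illustrative only and do not reproduce the project's actual data store.

    from typing import Any, Dict, Tuple

    class ScopedStore:
        """Toy scope-aware store: get() climbs towards the root scope."""

        def __init__(self) -> None:
            self._store: Dict[Tuple[str, str], Any] = {}

        def set(self, name: str, value: Any, scope: str = "general") -> None:
            self._store[(name, scope)] = value

        def get(self, name: str, scope: str = "general") -> Any:
            while True:
                if (name, scope) in self._store:
                    return self._store[(name, scope)]
                if "." not in scope:
                    raise KeyError(f"'{name}' not found in '{scope}' or any parent scope")
                scope, _ = scope.rsplit(".", 1)  # climb one scope level

    store = ScopedStore()
    store.set("fraction_of_training", 0.8)                     # general scope, i.e. [.]
    store.set("stations", ["DEBW107"], scope="general.train")  # subset-specific entry
    assert store.get("fraction_of_training", "general.train") == 0.8  # falls back to [.]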
@@ -61,25 +82,33 @@ class PreProcessing(RunEnvironment):

     def create_latex_report(self):
         """
-        This function creates tables with information on the station meta data and a summary on subset sample sizes.
+        Create tables with information on the station meta data and a summary on subset sample sizes.

-        * station_sample_size.md: see table below
-        * station_sample_size.tex: same as table below, but as latex table
+        * station_sample_size.md: see table below as markdown
+        * station_sample_size.tex: same as table below as latex table
         * station_sample_size_short.tex: reduced size table without any meta data besides station ID, as latex table

         All tables are stored inside experiment_path inside the folder latex_report. The table format (e.g. which meta
         data is highlighted) is currently hardcoded to have a stable table style. If further styles are needed, it is
         better to add an additional style than to modify the existing table styles.

+        +------------+-------------------------------------------+---------------+---------------+---------------+---------+-------+--------+
         | stat. ID   | station_name                              |   station_lon |   station_lat |   station_alt |   train |   val |   test |
-        |------------|-------------------------------------------|---------------|---------------|---------------|---------|-------|--------|
+        +============+===========================================+===============+===============+===============+=========+=======+========+
         | DEBW013    | Stuttgart Bad Cannstatt                   |        9.2297 |       48.8088 |           235 |    1434 |   712 |   1080 |
+        +------------+-------------------------------------------+---------------+---------------+---------------+---------+-------+--------+
         | DEBW076    | Baden-Baden                               |        8.2202 |       48.7731 |           148 |    3037 |   722 |    710 |
+        +------------+-------------------------------------------+---------------+---------------+---------------+---------+-------+--------+
         | DEBW087    | Schwäbische_Alb                           |        9.2076 |       48.3458 |           798 |    3044 |   714 |   1087 |
+        +------------+-------------------------------------------+---------------+---------------+---------------+---------+-------+--------+
         | DEBW107    | Tübingen                                  |        9.0512 |       48.5077 |           325 |    1803 |   715 |   1087 |
+        +------------+-------------------------------------------+---------------+---------------+---------------+---------+-------+--------+
         | DEBY081    | Garmisch-Partenkirchen/Kreuzeckbahnstraße |       11.0631 |       47.4764 |           735 |    2935 |   525 |    714 |
+        +------------+-------------------------------------------+---------------+---------------+---------------+---------+-------+--------+
         | # Stations | nan                                       |           nan |           nan |           nan |       6 |     6 |      6 |
+        +------------+-------------------------------------------+---------------+---------------+---------------+---------+-------+--------+
         | # Samples  | nan                                       |           nan |           nan |           nan |   12253 |  3388 |   4678 |
+        +------------+-------------------------------------------+---------------+---------------+---------------+---------+-------+--------+

         """
         meta_data = ['station_name', 'station_lon', 'station_lat', 'station_alt']
@@ -113,9 +142,11 @@
     def split_train_val_test(self) -> None:
         """
-        Splits all subsets. Currently: train, val, test and train_val (actually this is only the merge of train and val,
-        but as an separate generator). IMPORTANT: Do not change to order of the execution of create_set_split. The train
-        subset needs always to be executed at first, to set a proper transformation.
+        Split data into subsets.
+
+        Currently: train, val, test and train_val (actually this is only the merge of train and val, but as a separate
+        generator). IMPORTANT: Do not change the order of execution of create_set_split. The train subset always needs
+        to be executed first, to set a proper transformation.
         """
         fraction_of_training = self.data_store.get("fraction_of_training")
         stations = self.data_store.get("stations")
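The three report files named in the docstring map naturally onto a single pandas DataFrame. The snippet below only illustrates that mapping under the assumption of a pandas-based implementation (the module's actual code may differ); the values are copied from the docstring table, the file names from the list above, and `to_markdown` additionally requires the `tabulate` package.

    import pandas as pd

    # two sample rows taken from the docstring table
    df = pd.DataFrame(
        {"station_name": ["Stuttgart Bad Cannstatt", "Baden-Baden"],
         "station_lon": [9.2297, 8.2202], "station_lat": [48.8088, 48.7731],
         "station_alt": [235, 148],
         "train": [1434, 3037], "val": [712, 722], "test": [1080, 710]},
        index=pd.Index(["DEBW013", "DEBW076"], name="stat. ID"),
    )

    df.to_latex("station_sample_size.tex")        # full latex table
    with open("station_sample_size.md", "w") as f:
        f.write(df.to_markdown())                 # same table as markdown
    # reduced variant: station ID (index) plus sample counts only
    df[["train", "val", "test"]].to_latex("station_sample_size_short.tex")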
@@ -131,13 +162,16 @@
     @staticmethod
     def split_set_indices(total_length: int, fraction: float) -> Tuple[slice, slice, slice, slice]:
         """
-        create the training, validation and test subset slice indices for given total_length. The test data consists on
-        (1-fraction) of total_length (fraction*len:end). Train and validation data therefore are made from fraction of
-        total_length (0:fraction*len). Train and validation data is split by the factor 0.8 for train and 0.2 for
-        validation. In addition, split_set_indices returns also the combination of training and validation subset.
+        Create the training, validation and test subset slice indices for given total_length.
+
+        The test data consists of (1-fraction) of total_length (fraction*len:end). Train and validation data therefore
+        are made from the first fraction of total_length (0:fraction*len), which is split again by the factor 0.8 for
+        train and 0.2 for validation. In addition, split_set_indices also returns the combination of the training and
+        validation subset.

         :param total_length: total number of objects to split
         :param fraction: ratio between test and union of train/val data
+        :return: slices for each subset in the order: train, val, test, train_val
         """
         pos_test_split = int(total_length * fraction)
@@ -147,11 +181,13 @@ class PreProcessing(RunEnvironment):
         train_val_index = slice(0, pos_test_split)
         return train_index, val_index, test_index, train_val_index

-    def create_set_split(self, index_list: slice, set_name) -> None:
+    def create_set_split(self, index_list: slice, set_name: str) -> None:
         """
+        Create subsets and store in data store.
+
         Create the subset for given split index and store the DataGenerator with given set name in data store as
-        `generator`. Checks for all valid stations using the default (kw)args for given scope and creates the
-        DataGenerator for all valid stations. Also sets all transformation information, if subset is training set. Make
+        `generator`. Check for all valid stations using the default (kw)args for given scope and create the
+        DataGenerator for all valid stations. Also set all transformation information if the subset is the training set. Make
         sure that the train set is executed first, and all other subsets afterwards.

         :param index_list: slice to select this subset's stations from the full station list. If attribute
             use_all_stations_on_all_data_sets=True, this slice is ignored and all stations are used instead.
         :param set_name: name of the subset; also used as scope to store all information in the data store.
         """
         args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope=set_name)
         kwargs = self.data_store.create_args_dict(DEFAULT_KWARGS_LIST, scope=set_name)
         stations = args["stations"]
-        if self.data_store.get("use_all_stations_on_all_data_sets", scope=set_name):
+        if self.data_store.get("use_all_stations_on_all_data_sets"):
             set_stations = stations
         else:
             set_stations = stations[index_list]
         logging.debug(f"{set_name.capitalize()} stations (len={len(set_stations)}): {set_stations}")
+        # validate set
         set_stations = self.check_valid_stations(args, kwargs, set_stations, load_tmp=False, name=set_name)
         self.data_store.set("stations", set_stations, scope=set_name)
+        # create set generator and store
         set_args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope=set_name)
         data_set = DataGenerator(**set_args, **kwargs)
         self.data_store.set("generator", data_set, scope=set_name)
+        # extract transformation from train set
         if set_name == "train":
             self.data_store.set("transformation", data_set.transformation)
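The slicing rule of split_set_indices can be checked in isolation. In the sketch below, the surrounding lines come from the two hunks above, while the train and val slice assignments (elided between the hunks) are reconstructed from the docstring's 0.8/0.2 description:

    from typing import Tuple

    def split_set_indices(total_length: int, fraction: float) -> Tuple[slice, slice, slice, slice]:
        pos_test_split = int(total_length * fraction)
        train_index = slice(0, int(pos_test_split * 0.8))             # first 80% of the train/val part
        val_index = slice(int(pos_test_split * 0.8), pos_test_split)  # remaining 20%
        test_index = slice(pos_test_split, total_length)              # last (1 - fraction) of all data
        train_val_index = slice(0, pos_test_split)                    # union of train and val
        return train_index, val_index, test_index, train_val_index

    stations = [f"station_{i}" for i in range(10)]
    train, val, test, train_val = split_set_indices(len(stations), fraction=0.9)
    print(stations[train])      # stations 0-6 (7 items)
    print(stations[val])        # stations 7-8 (2 items)
    print(stations[test])       # station 9 (1 item)
    print(stations[train_val])  # stations 0-8 (9 items)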
@@ -178,8 +217,10 @@
     def check_valid_stations(args: Dict, kwargs: Dict, all_stations: List[str], load_tmp=True, save_tmp=True, name=None):
         """
-        Check if all given stations in `all_stations` are valid. Valid means, that there is data available for the given
-        time range (is included in `kwargs`). The shape and the loading time are logged in debug mode.
+        Check if all given stations in `all_stations` are valid.
+
+        Valid means that there is data available for the given time range (which is included in `kwargs`). The shape
+        and the loading time are logged in debug mode.

         :param args: Dictionary with required parameters for DataGenerator class (`data_path`, `network`, `stations`,
             `variables`, `interpolate_dim`, `target_dim`, `target_var`).
         :param kwargs: Dictionary with optional parameters for DataGenerator class (e.g. `limit_nan_fill`,
             `window_history_size`, `window_lead_time`).
         :param all_stations: All stations to check.
         :param name: name to display in the logging info message.
+        :return: Corrected list containing only valid station IDs.
         """
         t_outer = TimeTracking()
diff --git a/src/run_modules/training.py b/src/run_modules/training.py
index d54f0c6d..8bbd7723 100644
--- a/src/run_modules/training.py
+++ b/src/run_modules/training.py
@@ -19,7 +19,9 @@ from src.run_modules.run_environment import RunEnvironment

 class Training(RunEnvironment):
     """
-    Perform training.
+    Train your model with this module.
+
+    Schedule of training:
         #. set_generators(): set generators for training, validation and testing and distribute according to batch size
         #. make_predict_function(): create predict function before distribution on multiple nodes (detailed information
            in method description)
@@ -56,7 +58,7 @@ class Training(RunEnvironment):
     """

     def __init__(self):
-        """Set up training."""
+        """Set up and run training."""
         super().__init__()
         self.model: keras.Model = self.data_store.get("model", "model")
         self.train_set: Union[Distributor, None] = None
--
GitLab
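check_valid_stations implements a straightforward filter: try to load each station, keep it on success, drop it with a log message on failure, and track the overall time. A standalone sketch of that pattern follows; `load_station` is a hypothetical callable standing in for the DataGenerator-based loading in the patch, and the caught error types are assumptions.

    import logging
    import time
    from typing import Callable, List

    logging.basicConfig(level=logging.DEBUG)

    def check_valid_stations(all_stations: List[str], load_station: Callable[[str], list]) -> List[str]:
        """Return only those stations whose data can actually be loaded."""
        t_start = time.time()
        valid_stations = []
        for station in all_stations:
            try:
                data = load_station(station)  # should raise if no data is available
                logging.debug(f"{station} loaded, {len(data)} samples")
                valid_stations.append(station)
            except (FileNotFoundError, ValueError) as err:
                logging.info(f"remove station {station} because of: {err}")
        logging.debug(f"validity check finished after {time.time() - t_start:.2f} s")
        return valid_stations

    def demo_loader(station: str) -> list:
        if station.startswith("DEBW"):
            return [1, 2, 3]  # pretend data
        raise ValueError(f"no data for {station}")

    print(check_valid_stations(["DEBW013", "XX0000"], demo_loader))  # ['DEBW013']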