Skip to content
Snippets Groups Projects
Select Git revision
  • a8f12a620b4646c55ca57ef4e3a1c9304b2de4f9
  • master default protected
  • enxhi_issue460_remove_TOAR-I_access
  • michael_issue459_preprocess_german_stations
  • sh_pollutants
  • develop protected
  • release_v2.4.0
  • michael_issue450_feat_load-ifs-data
  • lukas_issue457_feat_set-config-paths-as-parameter
  • lukas_issue454_feat_use-toar-statistics-api-v2
  • lukas_issue453_refac_advanced-retry-strategy
  • lukas_issue452_bug_update-proj-version
  • lukas_issue449_refac_load-era5-data-from-toar-db
  • lukas_issue451_feat_robust-apriori-estimate-for-short-timeseries
  • lukas_issue448_feat_load-model-from-path
  • lukas_issue447_feat_store-and-load-local-clim-apriori-data
  • lukas_issue445_feat_data-insight-plot-monthly-distribution
  • lukas_issue442_feat_bias-free-evaluation
  • lukas_issue444_feat_choose-interp-method-cams
  • 414-include-crps-analysis-and-other-ens-verif-methods-or-plots
  • lukas_issue384_feat_aqw-data-handler
  • v2.4.0 protected
  • v2.3.0 protected
  • v2.2.0 protected
  • v2.1.0 protected
  • Kleinert_etal_2022_initial_submission
  • v2.0.0 protected
  • v1.5.0 protected
  • v1.4.0 protected
  • v1.3.0 protected
  • v1.2.1 protected
  • v1.2.0 protected
  • v1.1.0 protected
  • IntelliO3-ts-v1.0_R1-submit
  • v1.0.0 protected
  • v0.12.2 protected
  • v0.12.1 protected
  • v0.12.0 protected
  • v0.11.0 protected
  • v0.10.0 protected
  • IntelliO3-ts-v1.0_initial-submit
41 results

test_experiment_setup.py

Blame
  • config_preprocess_step2.py 7.31 KiB
    """
    Child class used for configuring the preprocessing step 2 runscript of the workflow.
    """
    __author__ = "Michael Langguth"
    __date__ = "2021-01-28"
    
    # import modules
    import os, glob
    from data_preprocess.dataset_options import known_datasets
    from config_utils import Config_runscript_base    # import parent class
    
    class Config_Preprocess2(Config_runscript_base):
        """
        Child class used for configuring the preprocessing step 2 runscript of the workflow.
        All configuration attributes (dataset, source/destination directories and - for ERA5 -
        the sequence length) are obtained via keyboard interaction in run_preprocess2.
        """

        cls_name = "Config_Preprocess2"

        # !!! Important note !!!
        # As long as we don't have runscript templates for all the datasets listed in known_datasets
        # or a generic template runscript, we need the following manual list
        allowed_datasets = ["era5", "moving_mnist"]  # manual subset of known_datasets().keys

        def __init__(self, venv_name, lhpc):
            """
            Initialize all attributes of the configuration for preprocessing step 2.
            :param venv_name: name of the virtual environment (handled by the parent class)
            :param lhpc: flag if the run takes place on a known HPC-system (handled by the parent class)
            """
            super().__init__(venv_name, lhpc)

            # initialize attributes related to runscript name
            self.long_name_wrk_step = "Preprocessing step 2"
            self.rscrpt_tmpl_prefix = "preprocess_data_"
            # initialize additional runscript-specific attributes to be set via keyboard interaction
            self.destination_dir = None
            self.sequence_length = None  # only needed for ERA5
            # list of variables to be written to runscript (appended for ERA5 dataset)
            self.list_batch_vars = ["VIRT_ENV_NAME", "source_dir", "destination_dir"]
            # copy over method for keyboard interaction
            self.run_config = Config_Preprocess2.run_preprocess2
        #
        # -----------------------------------------------------------------------------------
        #
        def run_preprocess2(self):
            """
            Runs the keyboard interaction for Preprocessing step 2.
            Sets the dataset, the runscript template/target names, the source- and
            destination-directories and (for ERA5 only) the sequence length.
            :return: all attributes of class Config_Preprocess2 set
            """
            method_name = Config_Preprocess2.run_preprocess2.__name__

            # decide which dataset is used
            dset_type_req_str = "Enter the name of the dataset for which TFrecords should be prepared for training:"
            dset_err = ValueError("Please select a dataset from the ones listed above.")

            self.dataset = Config_Preprocess2.keyboard_interaction(dset_type_req_str, Config_Preprocess2.check_dataset,
                                                                   dset_err, ntries=3)
            # now, we are also ready to set the correct name of the runscript template and the target
            self.runscript_template = self.rscrpt_tmpl_prefix + self.dataset + "_step2" +\
                                      self.suffix_template
            self.runscript_target = self.rscrpt_tmpl_prefix + self.dataset + "_step2" + ".sh"

            # get source dir (relative to base_dir_source!)
            source_dir_base = Config_Preprocess2.handle_source_dir(self, "preprocessedData")

            if self.dataset == "era5":
                file_type = "ERA5 pickle-files are"
            elif self.dataset == "moving_mnist":
                file_type = "The movingMNIST data file is"
            else:
                # defensive guard: check_dataset should make this unreachable, but without it
                # an unknown dataset would trigger an obscure NameError on file_type below
                raise ValueError("%{0}: Unknown dataset '{1}' passed.".format(method_name, self.dataset))
            source_req_str = "Choose a subdirectory listed above to {0} where the extracted {1} located:"\
                                 .format(source_dir_base, file_type)
            source_err = FileNotFoundError("Cannot retrieve " + file_type + " from passed path.")

            self.source_dir = Config_Preprocess2.keyboard_interaction(source_req_str, Config_Preprocess2.check_data_indir,
                                                                      source_err, ntries=3, suffix2arg=source_dir_base + "/")
            # Note: At this stage, self.source_dir is a top-level directory.
            # TFrecords are assumed to live in tfrecords-subdirectory,
            # input files are assumed to live in pickle-subdirectory
            self.destination_dir = os.path.join(self.source_dir, "tfrecords")
            self.source_dir = os.path.join(self.source_dir, "pickle")

            # check if expected data is available in source_dir (depending on dataset)
            # The following files are expected:
            # * ERA5: pickle-files
            # * moving_MNIST: single npy-file
            if self.dataset == "era5":
                # pickle files are expected to be stored in yearly-subdirectories, i.e. we need a wildcard here
                if not any(glob.iglob(os.path.join(self.source_dir, "*", "*X*.pkl"))):
                    raise FileNotFoundError("%{0}: Could not find any pickle-files under '{1}'".format(method_name, self.source_dir) +
                                            " which are expected for the ERA5-dataset.")
            elif self.dataset == "moving_mnist":
                if not os.path.isfile(os.path.join(self.source_dir, "mnist_test_seq.npy")):
                    raise FileNotFoundError("%{0}: Could not find expected file 'mnist_test_seq.npy' under {1}"
                                            .format(method_name, self.source_dir))

            # final keyboard interaction when ERA5-dataset is used
            if self.dataset == "era5":
                # get desired sequence length
                seql_req_str = "Enter desired total sequence length (i.e. number of frames/images):"
                # message matches get_seq_length, which accepts any integer >= 2 (i.e. larger than 1)
                seql_err = ValueError("sequence length must be an integer and larger than 1.")

                seql_str = Config_Preprocess2.keyboard_interaction(seql_req_str, Config_Preprocess2.get_seq_length,
                                                                   seql_err, ntries=3)
                self.sequence_length = int(seql_str)

                # sequence_length must additionally be written to the runscript for ERA5
                self.list_batch_vars.append("sequence_length")
        #
        # -----------------------------------------------------------------------------------
        #
        @staticmethod
        def check_dataset(dataset_name, silent=False):
            """
            Check if a runscript template exists for the passed dataset (i.e. if it is allowed).
            :param dataset_name: name of the dataset from keyboard interaction
            :param silent: flag if print-statements are executed
            :return: True if the dataset is supported, False otherwise
            """
            # NOTE: Templates are only available for ERA5 and moving_MNIST.
            #       After adding further templates or making the template generic,
            #       the latter part of the if-clause can be removed
            #       and further adaptions are required in the configuration chain
            if dataset_name not in Config_Preprocess2.allowed_datasets:
                if not silent:
                    print("The following datasets can be used for preprocessing step 2:")
                    for dataset_avail in Config_Preprocess2.allowed_datasets:
                        print("* " + dataset_avail)
                return False

            return True
        #
        # -----------------------------------------------------------------------------------
        #
        @staticmethod
        def check_data_indir(indir, silent=False):
            """
            Rough check of passed directory (if it exists at all).
            :param indir: path to passed input directory
            :param silent: flag if print-statements are executed
            :return: True if the directory exists, False otherwise
            """
            status = True
            if not os.path.isdir(indir):
                status = False
                if not silent:
                    print("Could not find data directory '{0}'.".format(indir))

            return status
        #
        # -----------------------------------------------------------------------------------
        #
        @staticmethod
        def get_seq_length(seq_length, silent=False):
            """
            Check if passed sequence length is larger than 1 (lower limit for meaningful prediction).
            :param seq_length: sequence length from keyboard interaction
            :param silent: flag if print-statements are executed (unused here, kept for a
                           uniform checker signature)
            :return: status with True confirming success
            """
            status = False
            if seq_length.strip().isnumeric():
                if int(seq_length) >= 2:
                    status = True

            return status
    #
    # -----------------------------------------------------------------------------------
    #