diff --git a/video_prediction_tools/utils/runscript_generator/config_preprocess_step1.py b/video_prediction_tools/utils/runscript_generator/config_preprocess_step1.py index 51d3de9f71c8c6d9b756ee800cc74aab49ae3413..6780a32c76f830e399368da79be60076b56761e6 100755 --- a/video_prediction_tools/utils/runscript_generator/config_preprocess_step1.py +++ b/video_prediction_tools/utils/runscript_generator/config_preprocess_step1.py @@ -20,7 +20,7 @@ class Config_Preprocess1(Config_runscript_base): cls_name = "Config_Preprocess1"#.__name__ - nvars = 3 # number of variables required for training + nvars_default = 3 # number of variables required for training def __init__(self, venv_name, lhpc): super().__init__(venv_name, lhpc) @@ -35,7 +35,7 @@ class Config_Preprocess1(Config_runscript_base): # initialize additional runscript-specific attributes to be set via keyboard interaction self.destination_dir = None self.years = None - self.variables = [None] * self.nvars + self.variables = [] self.sw_corner = [-999., -999.] # [np.nan, np.nan] self.nyx = [-999., -999.] # [np.nan, np.nan] # list of variables to be written to runscript @@ -54,8 +54,17 @@ class Config_Preprocess1(Config_runscript_base): """ method_name = Config_Preprocess1.run_preprocess1.__name__ - # get source_dir (no user interaction needed when directory tree is fixed) - self.source_dir = Config_Preprocess1.handle_source_dir(self, "extractedData") + src_dir_req_str = "Enter path to directory where netCDF-files of the ERA5 dataset are located " + \ + "(in yearly directories.). Just press enter if the default should be used." + sorurce_dir_err = NotADirectoryError("Passed directory does not exist.") + source_dir_str = Config_Preprocess1.keyboard_interaction(src_dir_req_str, Config_Preprocess1.src_dir_check, + sorurce_dir_err, ntries=3) + + if not source_dir_str: + # standard source_dir + self.source_dir = Config_Preprocess1.handle_source_dir(self, "extractedData") + else: + self.source_dir = source_dir_str # get years for preprocessing step 1 years_req_str = "Enter a comma-separated sequence of years from list above:" @@ -86,8 +95,9 @@ class Config_Preprocess1(Config_runscript_base): vars_err, ntries=2) vars_list = vars_str.split(",") + vars_list = [var.strip().lower() for var in vars_list] if len(vars_list) == 1: - self.variables = vars_list * Config_Preprocess1.nvars + self.variables = vars_list * Config_Preprocess1.nvars_default else: self.variables = [var.strip() for var in vars_list] @@ -169,6 +179,29 @@ class Config_Preprocess1(Config_runscript_base): str(year))) # auxiliary functions for keyboard interaction + @staticmethod + def src_dir_check(srcdir, silent=False): + """ + Checks if source directory exists. Also allows for empty strings. In this case, a default of the source + directory must be applied. + :param srcdir: directory path under which ERA5 netCDF-data is stored + :param silent: flag if print-statement are executed + :return: status with True confirming success + """ + method = Config_Preprocess1.src_dir_check.__name__ + + status = False + if srcdir: + if os.path.isdir(srcdir): + status = True + else: + if not silent: + print("%{0}: '{1}' does not exist.".format(method, srcdir)) + else: + status = True + + return status + @staticmethod def check_data_indir(indir, silent=False, recursive=True): """ @@ -216,7 +249,8 @@ class Config_Preprocess1(Config_runscript_base): if not status: inds_bad = [i for i, e in enumerate(check_years) if e] #np.where(~np.array(check_years))[0] if not silent: - print("%{0}: The following comma-separated elements could not be interpreted as valid years:".format(method)) + print("%{0}: The following comma-separated elements could not be interpreted as valid years:" + .format(method)) for ind in inds_bad: print(years_list[ind]) return status @@ -245,15 +279,15 @@ class Config_Preprocess1(Config_runscript_base): check_vars = [var.strip().lower() in known_vars for var in vars_list] status = all(check_vars) if not status: - inds_bad = [i for i, e in enumerate(check_vars) if e] # np.where(~np.array(check_vars))[0] + inds_bad = [i for i, e in enumerate(check_vars) if e] # np.where(~np.array(check_vars))[0] if not silent: print("%{0}: The following comma-separated elements are unknown variables:".format(method)) for ind in inds_bad: print(vars_list[ind]) return status - if not (len(check_vars) == Config_Preprocess1.nvars or len(check_vars) == 1): - if not silent: print("%{0}: Unexpected number of variables passed ({1} vs. {2}).".format(method, len(check_vars), Config_Preprocess1.nvars)) + if not len(check_vars) >= 1: + if not silent: print("%{0}: Pass at least one input variable".format(method)) status = False return status @@ -321,9 +355,11 @@ class Config_Preprocess1(Config_runscript_base): else: if not silent: if not check_nyx[0]: - print("%{0}: Number of grid points in meridional direction must be smaller than {1:d}".format(method, ny_max)) + print("%{0}: Number of grid points in meridional direction must be smaller than {1:d}" + .format(method, ny_max)) if not check_nyx[1]: - print("%{0}: Number of grid points in zonal direction must be smaller than {1:d}".format(method, nx_max)) + print("%{0}: Number of grid points in zonal direction must be smaller than {1:d}" + .format(method, nx_max)) else: if not silent: print("%{0}: Number of grid points must be integers.".format(method))