From 5ab1dc387f7974622139daff21c36cd9a75a9985 Mon Sep 17 00:00:00 2001
From: Michael <m.langguth@fz-juelich.de>
Date: Tue, 9 Feb 2021 20:39:35 +0100
Subject: [PATCH] Add handling of basic source directory to all classes for the
 runscripts of the workflow steps.

---
 .../config_runscripts/config_extraction.py    | 15 ++++++--
 .../config_preprocess_step1.py                | 36 ++++++++++++++-----
 .../config_preprocess_step2.py                | 29 ++++++++-------
 .../config_runscripts/config_training.py      | 12 ++++---
 4 files changed, 64 insertions(+), 28 deletions(-)

diff --git a/video_prediction_tools/config_runscripts/config_extraction.py b/video_prediction_tools/config_runscripts/config_extraction.py
index 2157b946..4035fed2 100644
--- a/video_prediction_tools/config_runscripts/config_extraction.py
+++ b/video_prediction_tools/config_runscripts/config_extraction.py
@@ -48,6 +48,11 @@ class Config_Extraction(Config_runscript_base):
         self.year = Config_Extraction.keyboard_interaction(year_req_str, Config_Extraction.check_year,
                                                            year_err, ntries = 2, test_arg="2012")
 
+        # final check for input data
+        path_year = os.path.join(self.source_dir, self.year)
+        if not Config_Extraction.check_data_indir(path_year, silent=True, recursive=False):
+            raise FileNotFoundError("Cannot retrieve input data from {0}".format(path_year))
+
         # set destination directory based on base directory which can be retrieved from the template runscript
         base_dir = Config_Extraction.get_var_from_runscript(self.runscript_template, "destination_dir")
         self.destination_dir = os.path.join(base_dir, "extracted_data", self.year)
@@ -57,21 +62,25 @@ class Config_Extraction(Config_runscript_base):
     #
     # auxiliary functions for keyboard interaction
     @staticmethod
-    def check_data_indir(indir, silent=False):
+    def check_data_indir(indir, silent=False, recursive=True):
         """
         Check recursively for existence era5 netCDF-files in indir.
         This is just a simplified check, i.e. the script will fail if the directory tree is not
         built up like '<indir>/YYYY/MM/'.
-        Also used in Config_preprocess1!
         :param indir: path to passed input directory
         :param silent: flag if print-statement are executed
+        :param recursive: flag if one-level (!) recursive search should be performed
         :return: status with True confirming success
         """
         status = False
         if os.path.isdir(indir):
             # the built-in 'any'-function has a short-sircuit mechanism, i.e. returns True
             # if the first True element is met
-            fexist = any(glob.glob(os.path.join(indir, "**", "*era5*.nc"), recursive=True))
+            if recursive
+                fexist = any(glob.glob(os.path.join(indir, "*", "*era5*.nc")))
+            else:
+                fexist = any(glob.glob(os.path.join(indir, "*era5*.nc")))
+
             if fexist:
                 status = True
             else:
diff --git a/video_prediction_tools/config_runscripts/config_preprocess_step1.py b/video_prediction_tools/config_runscripts/config_preprocess_step1.py
index 9abb124d..e58aac5b 100644
--- a/video_prediction_tools/config_runscripts/config_preprocess_step1.py
+++ b/video_prediction_tools/config_runscripts/config_preprocess_step1.py
@@ -27,12 +27,14 @@ class Config_Preprocess1(Config_runscript_base):
         self.runscript_template = self.rscrpt_tmpl_prefix + self.dataset + "_step1" + self.suffix_template
         self.runscript_target = self.rscrpt_tmpl_prefix + self.dataset + "_step1" + ".sh"
         # initialize additional runscript-specific attributes to be set via keyboard interaction
+        self.destination_dir = None
         self.years = None
         self.variables = [None] * self.nvars
         self.lat_inds = [-1 -1] #[np.nan, np.nan]
         self.lon_inds = [-1 -1] #[np.nan, np.nan]
         # list of variables to be written to runscript
-        self.list_batch_vars = ["VIRT_ENV_NAME", "source_dir", "years", "variables", "lat_inds", "lon_inds"]
+        self.list_batch_vars = ["VIRT_ENV_NAME", "source_dir", "destination_dir", "years", "variables",
+                                "lat_inds", "lon_inds"]
         # copy over method for keyboard interaction
         self.run_config = Config_Preprocess1.run_preprocess1
     #
@@ -44,19 +46,29 @@ class Config_Preprocess1(Config_runscript_base):
         :return: all attributes of class Config_Preprocess1 are set
         """
         # get source_dir
-        dataset_req_str = "Enter the path where the extracted ERA5 netCDF-files are located:\n"
+        source_dir_base = Config_Preprocess1.handle_source_dir(self, "extractedData")
+
+        dataset_req_str = "Choose a subdirectory listed above where the extracted ERA5 files are located:\n"
         dataset_err = FileNotFoundError("Cannot retrieve extracted ERA5 netCDF-files from passed path.")
 
         self.source_dir = Config_Preprocess1.keyboard_interaction(dataset_req_str, Config_Preprocess1.check_data_indir,
-                                                                  dataset_err, ntries=3)
+                                                                  dataset_err, ntries=3, suffix2arg=source_dir_base+"/")
 
         # get years for preprcessing step 1
         years_req_str = "Enter a comma-separated sequence of years (format: YYYY):\n"
         years_err = ValueError("Cannot get years for preprocessing.")
         years_str = Config_Preprocess1.keyboard_interaction(years_req_str, Config_Preprocess1.check_years,
-                                                           years_err, ntries=2)
+                                                            years_err, ntries=2)
 
         self.years = [year.strip() for year in years_str.split(",")]
+        # final check data availability for each year
+        for year in self.years:
+            year_path = os.path.join(self.source_dir, year)
+            status = Config_Preprocess1.check_data_indir(year_path, recursive=False)
+            if status:
+                print("Data availability checked for year {0}".format(year))
+            else:
+                raise FileNotFoundError("Cannot retrieve ERA5 netCDF-files from {0}".format(year_path))
 
         # get variables for later training
         print("**** Info ****\n List of known variables which can be processed")
@@ -94,26 +106,32 @@ class Config_Preprocess1(Config_runscript_base):
         lon_inds_list = lon_inds_str.split(",")
         self.lon_inds = [ind.strip() for ind in lon_inds_list]
 
+        # set destination directory based on base directory which can be retrieved from the template runscript
+        base_dir = Config_Preprocess1.get_var_from_runscript(self.runscript_template, "destination_dir")
+        self.destination_dir = os.path.join(base_dir, "preprocessedData", "era5-Y{0}-{1}M01to12"
+                                            .format(min(years), max(years)))
+
     #
     # -----------------------------------------------------------------------------------
     #
     # auxiliary functions for keyboard interaction
     @staticmethod
-    def check_data_indir(indir, silent=False):
+    def check_data_indir(indir, silent=False, recursive=True):
         """
         Check recursively for existence era5 netCDF-files in indir.
-        This is just a simplified check, i.e. the script will fail if the directory tree is not
-        built up like '<indir>/YYYY/MM/'.
-        Also used in Config_preprocess1!
         :param indir: path to passed input directory
         :param silent: flag if print-statement are executed
+        :param recursive: flag if recursive search should be performed
         :return: status with True confirming success
         """
         status = False
         if os.path.isdir(indir):
             # the built-in 'any'-function has a short-sircuit mechanism, i.e. returns True
             # if the first True element is met
-            fexist = any(glob.glob(os.path.join(indir, "**", "*era5*.nc"), recursive=True))
+            if recursive:
+                fexist = any(glob.glob(os.path.join(indir, "**", "*era5*.nc"), recursive=True))
+            else:
+                fexist = any(glob.glob(os.path.join(indir, "*era5*.nc")))
             if fexist:
                 status = True
             else:
diff --git a/video_prediction_tools/config_runscripts/config_preprocess_step2.py b/video_prediction_tools/config_runscripts/config_preprocess_step2.py
index 9ed2443f..4db2fbd9 100644
--- a/video_prediction_tools/config_runscripts/config_preprocess_step2.py
+++ b/video_prediction_tools/config_runscripts/config_preprocess_step2.py
@@ -37,7 +37,7 @@ class Config_Preprocess2(Config_runscript_base):
     def run_preprocess2(self):
         """
         Runs the keyboard interaction for Preprocessing step 2
-        :return: all attributes of class Config_Preprocess1 are set
+        :return: all attributes of class Config_Preprocess2 set
         """
         # decide which dataset is used
         dset_type_req_str = "Enter the name of the dataset for which TFrecords should be prepard for training:\n"
@@ -50,25 +50,32 @@ class Config_Preprocess2(Config_runscript_base):
                                   self.suffix_template
         self.runscript_target = self.rscrpt_tmpl_prefix + self.dataset + "_step2" + ".sh"
 
-        # get source dir
+        # get source dir (relative to base_dir_source!)
+        source_dir_base = Config_Preprocess2.handle_source_dir(self, "preprocessedData")
+
         if self.dataset == "era5":
             file_type = "ERA5 pickle-files are"
         elif self.dataset == "moving_mnist":
             file_type = "The movingMNIST data file is"
-        source_req_str = "Enter the path where the extracted "+file_type+" located:\n"
+        source_req_str = "Choose a subdirectory listed above to {0} where the extracted {1} located:\n"\
+                             .format(base_dir_source, file_type)
         source_err = FileNotFoundError("Cannot retrieve "+file_type+" from passed path.")
 
         self.source_dir = Config_Preprocess2.keyboard_interaction(source_req_str, Config_Preprocess2.check_data_indir,
-                                                                  source_err, ntries=3)
+                                                                  source_err, ntries=3, suffix2arg=source_dir_base+"/")
+        # Note: At this stage, self.source_dir is a top-level directory.
+        # TFrecords are assumed to live in tfrecords-subdirectory,
+        # input files are assumed to live in pickle-subdirectory
+        self.destination_dir = os.path.join(source_dir_base, "tfrecords")
+        self.source_dir = os.path.join(self.source_dir, "pickle")
 
-        base_dir, _ = os.path.split(self.source_dir)
-        self.destination_dir = os.path.join(base_dir, "tfrecords")
         # check if expected data is available in source_dir (depending on dataset)
         # The following files are expected:
         # * ERA5: pickle-files
-        # * moving_MNIST: singgle npy-file
+        # * moving_MNIST: single npy-file
         if self.dataset == "era5":
-            if not any(glob.glob(os.path.join(self.source_dir, "**", "*X*.pkl"), recursive=True)):
+            # pickle files are expected to be stored in yearly-subdirectories, i.e. we need a wildcard here
+            if not any(glob.glob(os.path.join(self.source_dir, "*", "*X*.pkl"))):
                 raise FileNotFoundError("Could not find any pickle-files under '{0}'".format(self.source_dir) +
                                         "which are expected for the ERA5-dataset.".format(self.source_dir))
         elif self.dataset == "moving_mnist":
@@ -111,11 +118,9 @@ class Config_Preprocess2(Config_runscript_base):
     @staticmethod
     def check_data_indir(indir, silent=False):
         """
-        Check if the passed directory exists. If the required input data is really available depends on the dataset
-        and therefore has to be done afterwards
-        :param indir: path to input directory from keyboard interaction
+        Rough check of passed directory (if it exist at all)
+        :param indir: path to passed input directory
         :param silent: flag if print-statement are executed
-        :return: status with True confirming success
         """
         status = True
         if not os.path.isdir(indir):
diff --git a/video_prediction_tools/config_runscripts/config_training.py b/video_prediction_tools/config_runscripts/config_training.py
index d014d4c7..93e77091 100644
--- a/video_prediction_tools/config_runscripts/config_training.py
+++ b/video_prediction_tools/config_runscripts/config_training.py
@@ -20,7 +20,7 @@ class Config_Train(Config_runscript_base):
     # !!! Important note !!!
     # As long as we don't have runscript templates for all the datasets listed in known_datasets
     # or a generic template runscript, we need the following manual list
-    allowed_datasets = ["era5","moving_mnist"]  # known_datasets().keys
+    allowed_datasets = ["era5", "moving_mnist"]  # known_datasets().keys
 
     def __init__(self, venv_name, lhpc):
         super().__init__(venv_name, lhpc)
@@ -56,15 +56,19 @@ class Config_Train(Config_runscript_base):
         self.runscript_target = self.rscrpt_tmpl_prefix + self.dataset + ".sh"
 
         # get the source directory
+        # get source dir (relative to base_dir_source!)
+        source_dir_base = Config_Train.handle_source_dir(self, "preprocessedData")
 
-        expdir_req_str = "Enter the path to the preprocessed data (directory where tf-records files are located):\n"
+        expdir_req_str = "Choose a subdirectory listed above where the preprocessed TFrecords are located:\n"
         expdir_err = FileNotFoundError("Could not find any tfrecords.")
 
         self.source_dir = Config_Train.keyboard_interaction(expdir_req_str, Config_Train.check_expdir,
-                                                            expdir_err, ntries=3)
+                                                            expdir_err, ntries=3, suffix2arg=source_dir_base+"/")
+        # expand source_dir by tfrecords-subdirectory
+        self.source_dir = os.path.join(self.source_dir, "tfrecords")
 
         # split up directory path in order to retrieve exp_dir used for setting up the destination directory
-        exp_dir_split = path_rec_split(self.source_dir)
+        exp_dir_split = Config_Train.path_rec_split(self.source_dir)
         index = [idx for idx, s in enumerate(exp_dir_split) if self.dataset in s]
         if index == []:
             raise ValueError(
-- 
GitLab