Skip to content
Snippets Groups Projects
Commit ef2755df authored by Michael Langguth's avatar Michael Langguth
Browse files

Adaptions and corrections to config_postprocess.py.

parent ce200cbb
Branches
Tags
No related merge requests found
Pipeline #59472 passed
...@@ -30,8 +30,8 @@ class Config_Postprocess(Config_runscript_base): ...@@ -30,8 +30,8 @@ class Config_Postprocess(Config_runscript_base):
self.checkpoint_dir = None self.checkpoint_dir = None
self.destination_dir = None self.destination_dir = None
# list of variables to be written to runscript # list of variables to be written to runscript
self.list_batch_vars = ["VIRT_ENV_NAME", "source_dir", "destination_dir", self.list_batch_vars = ["VIRT_ENV_NAME", "source_dir", "results_dir",
"checkpoint_dir", "model", "dataset"] "checkpoint_dir", "model"]
# copy over method for keyboard interaction # copy over method for keyboard interaction
self.run_config = Config_Postprocess.run_postprocess self.run_config = Config_Postprocess.run_postprocess
# #
...@@ -43,7 +43,7 @@ class Config_Postprocess(Config_runscript_base): ...@@ -43,7 +43,7 @@ class Config_Postprocess(Config_runscript_base):
:return: all attributes of class postprocess are set :return: all attributes of class postprocess are set
""" """
# decide which dataset is used # decide which dataset is used
dset_type_req_str = "Enter the name of the dataset on which training was performed:\n" dset_type_req_str = "Enter the name of the dataset on which training was performed:"
dset_err = ValueError("Please select a dataset from the ones listed above.") dset_err = ValueError("Please select a dataset from the ones listed above.")
self.dataset = Config_Postprocess.keyboard_interaction(dset_type_req_str, Config_Postprocess.check_dataset, self.dataset = Config_Postprocess.keyboard_interaction(dset_type_req_str, Config_Postprocess.check_dataset,
...@@ -56,21 +56,22 @@ class Config_Postprocess(Config_runscript_base): ...@@ -56,21 +56,22 @@ class Config_Postprocess(Config_runscript_base):
# get the 'checkpoint-directory', i.e. the directory where the trained model parameters are stored # get the 'checkpoint-directory', i.e. the directory where the trained model parameters are stored
# Note that the remaining information (model, results-directory etc.) can be retrieved form it!!! # Note that the remaining information (model, results-directory etc.) can be retrieved form it!!!
trained_dir_req_str = "Enter the absolute (!) path to the model checkpoint directory" + \ trained_dir_req_str = "Enter the absolute (!) path to the model checkpoint directory" + \
" for which postprocessing should be done:\n" " for which postprocessing should be done:"
trained_err = FileNotFoundError("No trained model parameters found.") trained_err = FileNotFoundError("No trained model parameters found.")
self.checkpoint_dir = Config_Postprocess.keyboard_interaction(trained_dir_req_str, self.checkpoint_dir = Config_Postprocess.keyboard_interaction(trained_dir_req_str,
Config_Postprocess.check_traindir, Config_Postprocess.check_traindir,
trained_err, ntries=3) trained_err, ntries=3)
# get the relevant information from checlpoint_dir in order to construct source_dir and results_dir # get the relevant information from checkpoint_dir in order to construct source_dir and results_dir
# (following naming convention) # (following naming convention)
cp_dir_split = Config_Postprocess.path_rec_split(self.checkpoint_dir) cp_dir_split = Config_Postprocess.path_rec_split(self.checkpoint_dir)
cp_dir_split = list(filter(None, cp_dir_split)) # get rid of empty list elements
base_dir, exp_dir_base, exp_dir = os.path.join(*cp_dir_split[:-3]), cp_dir_split[-3], cp_dir_split[-1] base_dir, exp_dir_base, exp_dir = "/"+os.path.join(*cp_dir_split[:-4]), cp_dir_split[-3], cp_dir_split[-1]
self.model = Config_Postprocess.check_model(cp_dir_split[-2]) self.model = Config_Postprocess.check_model(cp_dir_split[-2])
self.source_dir = Config_Postprocess.check_source(os.path.join(base_dir, exp_dir_base)) self.source_dir = Config_Postprocess.check_source(os.path.join(base_dir, "preprocessedData", exp_dir_base))
self.destination_dir = os.path.join(base_dir, "results", exp_dir_base, self.model, exp_dir) self.destination_dir = os.path.join(base_dir, "results", exp_dir_base, self.model, exp_dir)
# #
# ----------------------------------------------------------------------------------- # -----------------------------------------------------------------------------------
...@@ -128,7 +129,7 @@ class Config_Postprocess(Config_runscript_base): ...@@ -128,7 +129,7 @@ class Config_Postprocess(Config_runscript_base):
""" """
if not model_in in Config_Postprocess.list_models: if not model_in in Config_Postprocess.list_models:
print("**** Known models ****") print("**** Known models ****")
for model in Config_Postprocess: print(model) for model in Config_Postprocess.list_models: print(model)
raise ValueError("{0} is an unknown model (see list of known models above).".format(model_in)) raise ValueError("{0} is an unknown model (see list of known models above).".format(model_in))
else: else:
pass pass
...@@ -144,7 +145,7 @@ class Config_Postprocess(Config_runscript_base): ...@@ -144,7 +145,7 @@ class Config_Postprocess(Config_runscript_base):
:param source_dir_in: input directory to be checked :param source_dir_in: input directory to be checked
:return: returns source_dir_in when check is passed successfully :return: returns source_dir_in when check is passed successfully
""" """
real_dir = os.path.join(source_dir_in, "tfrecords") real_dir = os.path.join(source_dir_in, "tfrecords", "tfrecords")
if os.path.isdir(real_dir): if os.path.isdir(real_dir):
file_list = glob.glob(os.path.join(real_dir, "sequence*.tfrecords")) file_list = glob.glob(os.path.join(real_dir, "sequence*.tfrecords"))
if len(file_list) > 0: if len(file_list) > 0:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment