diff --git a/video_prediction_tools/config_runscripts/config_postprocess.py b/video_prediction_tools/config_runscripts/config_postprocess.py index 33aab7f213542862743021fd87d282e4b78ad4d4..4f177ded700655c75c54a7336ac2d5257619de84 100644 --- a/video_prediction_tools/config_runscripts/config_postprocess.py +++ b/video_prediction_tools/config_runscripts/config_postprocess.py @@ -30,8 +30,8 @@ class Config_Postprocess(Config_runscript_base): self.checkpoint_dir = None self.destination_dir = None # list of variables to be written to runscript - self.list_batch_vars = ["VIRT_ENV_NAME", "source_dir", "destination_dir", - "checkpoint_dir", "model", "dataset"] + self.list_batch_vars = ["VIRT_ENV_NAME", "source_dir", "results_dir", + "checkpoint_dir", "model"] # copy over method for keyboard interaction self.run_config = Config_Postprocess.run_postprocess # @@ -43,7 +43,7 @@ class Config_Postprocess(Config_runscript_base): :return: all attributes of class postprocess are set """ # decide which dataset is used - dset_type_req_str = "Enter the name of the dataset on which training was performed:\n" + dset_type_req_str = "Enter the name of the dataset on which training was performed:" dset_err = ValueError("Please select a dataset from the ones listed above.") self.dataset = Config_Postprocess.keyboard_interaction(dset_type_req_str, Config_Postprocess.check_dataset, @@ -56,21 +56,22 @@ class Config_Postprocess(Config_runscript_base): # get the 'checkpoint-directory', i.e. the directory where the trained model parameters are stored # Note that the remaining information (model, results-directory etc.) can be retrieved form it!!! trained_dir_req_str = "Enter the absolute (!) path to the model checkpoint directory" + \ - " for which postprocessing should be done:\n" + " for which postprocessing should be done:" trained_err = FileNotFoundError("No trained model parameters found.") self.checkpoint_dir = Config_Postprocess.keyboard_interaction(trained_dir_req_str, Config_Postprocess.check_traindir, trained_err, ntries=3) - # get the relevant information from checlpoint_dir in order to construct source_dir and results_dir + # get the relevant information from checkpoint_dir in order to construct source_dir and results_dir # (following naming convention) cp_dir_split = Config_Postprocess.path_rec_split(self.checkpoint_dir) + cp_dir_split = list(filter(None, cp_dir_split)) # get rid of empty list elements - base_dir, exp_dir_base, exp_dir = os.path.join(*cp_dir_split[:-3]), cp_dir_split[-3], cp_dir_split[-1] + base_dir, exp_dir_base, exp_dir = "/"+os.path.join(*cp_dir_split[:-4]), cp_dir_split[-3], cp_dir_split[-1] self.model = Config_Postprocess.check_model(cp_dir_split[-2]) - self.source_dir = Config_Postprocess.check_source(os.path.join(base_dir, exp_dir_base)) + self.source_dir = Config_Postprocess.check_source(os.path.join(base_dir, "preprocessedData", exp_dir_base)) self.destination_dir = os.path.join(base_dir, "results", exp_dir_base, self.model, exp_dir) # # ----------------------------------------------------------------------------------- @@ -128,7 +129,7 @@ class Config_Postprocess(Config_runscript_base): """ if not model_in in Config_Postprocess.list_models: print("**** Known models ****") - for model in Config_Postprocess: print(model) + for model in Config_Postprocess.list_models: print(model) raise ValueError("{0} is an unknown model (see list of known models above).".format(model_in)) else: pass @@ -144,7 +145,7 @@ class Config_Postprocess(Config_runscript_base): :param source_dir_in: input directory to be checked :return: returns source_dir_in when check is passed successfully """ - real_dir = os.path.join(source_dir_in, "tfrecords") + real_dir = os.path.join(source_dir_in, "tfrecords", "tfrecords") if os.path.isdir(real_dir): file_list = glob.glob(os.path.join(real_dir, "sequence*.tfrecords")) if len(file_list) > 0: