diff --git a/video_prediction_tools/utils/runscript_generator/config_postprocess.py b/video_prediction_tools/utils/runscript_generator/config_postprocess.py index eff2694bbd1dc381726e54bb4fb46bf0af5ddb4c..544337834fd4a9a1241f8e20e2cf186d8a2265ce 100755 --- a/video_prediction_tools/utils/runscript_generator/config_postprocess.py +++ b/video_prediction_tools/utils/runscript_generator/config_postprocess.py @@ -76,24 +76,32 @@ class Config_Postprocess(Config_runscript_base): self.model = os.path.basename(dir_base) # List the subdirectories... _ = Config_Postprocess.get_subdir_list(dir_base) - # ... and obtain the checkpoint directory + + # Chose the checkpoint directory + ckp_req_str = "Chose a checkpoint directory from the list above:" + ckp_req_err = NotADirectoryError("Could not find the passed directory.") + dir_base = Config_Postprocess.keyboard_interaction(ckp_req_str, Config_Postprocess.check_dir, ckp_req_err, + prefix2arg=dir_base+"/", ntries=2) + # List the subdirectories... + _ = Config_Postprocess.get_subdir_list(dir_base) + # ... and obtain the model directory with checkpoints trained_dir_req_str = "Choose a trained model from the experiment list above:" trained_err = FileNotFoundError("No trained model parameters found.") - self.checkpoint_dir = Config_Postprocess.keyboard_interaction(trained_dir_req_str, Config_Postprocess.check_traindir, trained_err, ntries=3, prefix2arg=dir_base+"/") + # get the relevant information from checkpoint_dir in order to construct source_dir and results_dir # (following naming convention) cp_dir_split = Config_Postprocess.path_rec_split(self.checkpoint_dir) cp_dir_split = list(filter(None, cp_dir_split)) # get rid of empty list elements - base_dir, exp_dir_base, exp_dir = "/"+os.path.join(*cp_dir_split[:-4]), cp_dir_split[-3], cp_dir_split[-1] + base_dir, exp_dir_base, exp_dir = "/"+os.path.join(*cp_dir_split[:-4]), cp_dir_split[-3], cp_dir_split[-2] self.runscript_target = self.rscrpt_tmpl_prefix + self.dataset + "_" + exp_dir + ".sh" # Set results_dir - self.results_dir = os.path.join(base_dir, "results", exp_dir_base, self.model, exp_dir) + self.results_dir = os.path.join(base_dir, "results", exp_dir_base,self.model, exp_dir) return # Decide if quick evaluation should be performed