diff --git a/video_prediction_tools/main_scripts/main_visualize_postprocess.py b/video_prediction_tools/main_scripts/main_visualize_postprocess.py index 774f07650703a6e67bc6adc1df0527e673dc0a37..96f3f52b2c97b698c52338e6e9ca56b9ca89b8e7 100644 --- a/video_prediction_tools/main_scripts/main_visualize_postprocess.py +++ b/video_prediction_tools/main_scripts/main_visualize_postprocess.py @@ -1229,8 +1229,12 @@ def main(): help="Channel which is used for evaluation.") parser.add_argument("--lquick_evaluation", "-lquick", dest="lquick", default=False, action="store_true", help="Flag if (reduced) quick evaluation based on MSE is performed.") + parser.add_argument("--evaluation_metric_quick", "metric_quick", dest="metric_quick", type=str, default="mse", + help="(Only) metric to evaluate when quick evaluation (-lquick) is chosen.") args = parser.parse_args() + method = os.path.basename(__file__) + print('----------------------------------- Options ------------------------------------') for k, v in args._get_kwargs(): print(k, "=", v) @@ -1239,10 +1243,13 @@ def main(): eval_metrics = args.eval_metrics results_dir = args.results_dir if args.lquick: # in case of quick evaluation, onyl evaluate MSE and modify results_dir - eval_metrics = ["mse"] + eval_metrics = [args.metric_quick] if not os.path.isfile(args.checkpoint): - raise ValueError("Pass a specific checkpoint-file for quick evaluation.") - results_dir = args.results_dir + "_{0}".format(os.path.basename(args.checkpoint)) + raise ValueError("%{0}: Pass a specific checkpoint-file for quick evaluation.".format(method)) + chp = os.path.basename(args.checkpoint) + results_dir = args.results_dir + "_{0}".format(chp) + print("%{0}: Quick evaluation is chosen. \n * evaluation metric: {0}\n".format(args.metric_quick) + + "* checkpointed model: {0}\n * no conditional quantile and forecast example plots".format(chp)) # initialize postprocessing instance postproc_instance = Postprocess(results_dir=results_dir, checkpoint=args.checkpoint, mode="test",