diff --git a/video_prediction_tools/main_scripts/main_visualize_postprocess.py b/video_prediction_tools/main_scripts/main_visualize_postprocess.py
index 4cd800f880b927b84ac8d7a0a929c78bac777a94..774f07650703a6e67bc6adc1df0527e673dc0a37 100644
--- a/video_prediction_tools/main_scripts/main_visualize_postprocess.py
+++ b/video_prediction_tools/main_scripts/main_visualize_postprocess.py
@@ -30,10 +30,12 @@ from postprocess_plotting import plot_avg_eval_metrics, plot_cond_quantile, crea
 
 
 class Postprocess(TrainModel):
-    def __init__(self, results_dir: str = None, checkpoint: str= None, mode: str = "test", batch_size: int = None,
+    def __init__(self, results_dir: str = None, checkpoint: str = None, mode: str = "test", batch_size: int = None,
                  num_stochastic_samples: int = 1, stochastic_plot_id: int = 0, gpu_mem_frac: float = None,
                  seed: int = None, channel: int = 0, args=None, run_mode: str = "deterministic",
-                 eval_metrics: List = ("mse", "psnr", "ssim","acc"), clim_path: str ="/p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/T2monthly"):
+                 eval_metrics: List = ("mse", "psnr", "ssim", "acc"),
+                 clim_path: str = "/p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/T2monthly",
+                 lquick: bool = None):
         """
         Initialization of the class instance for postprocessing (generation of forecasts from trained model +
         basic evauation).
@@ -50,7 +52,8 @@ class Postprocess(TrainModel):
         :param args: namespace of parsed arguments
         :param run_mode: "deterministic" or "stochastic", default: "deterministic", "stochastic is not supported yet!!!
         :param eval_metrics: metrics used to evaluate the trained model
-        :param clim_path: the path to the climatology nc file
+        :param clim_path: the path to the netCDF-file storing climatological data
+        :param lquick: flag for quick evaluation
         """
         # copy over attributes from parsed argument
         self.results_dir = self.output_dir = os.path.normpath(results_dir)
@@ -68,6 +71,7 @@ class Postprocess(TrainModel):
         self.run_mode = run_mode
         self.mode = mode
         self.channel = channel
+        self.lquick = lquick
         # Attributes set during runtime
         self.norm_cls = None
         # configuration of basic evaluation
@@ -82,7 +86,7 @@ class Postprocess(TrainModel):
         self.model_hparams_dict_load = self.get_model_hparams_dict()
         # set input paths and forecast product dictionary
         self.input_dir, self.input_dir_pkl = self.get_input_dirs()
-        self.fcst_products = {"persistence": "pfcst", self.model: "mfcst"}
+        self.fcst_products = {self.model: "mfcst"} if lquick else {"persistence": "pfcst", self.model: "mfcst"}
         # correct number of stochastic samples if necessary
         self.check_num_stochastic_samples()
         # get metadata
@@ -102,8 +106,8 @@ class Postprocess(TrainModel):
         self.setup_model(mode=self.mode)
         self.setup_graph()
         self.setup_gpu_config()
-        self.load_climdata()
-
+        if "acc" in eval_metrics:
+            self.load_climdata()
 
     # Methods that are called during initialization
     def get_input_dirs(self):
@@ -551,21 +555,24 @@ class Postprocess(TrainModel):
 
                 if os.path.exists(nc_fname):
                     print("%{0}: The file '{1}' already exists and is therefore skipped".format(method, nc_fname))
-                else:
+                elif not self.lquick:
                     self.save_ds_to_netcdf(batch_ds.isel(init_time=i), nc_fname)
+                else:
+                    pass
             # end of batch-loop
 
             # write evaluation metric to corresponding dataset and sa
             eval_metric_ds = self.populate_eval_metric_ds(eval_metric_ds, batch_ds, sample_ind,
                                                           self.vars_in[self.channel])
-            cond_quantiple_ds = Postprocess.append_ds(batch_ds, cond_quantiple_ds, self.cond_quantile_vars, "init_time",dtype=np.float16)
+            if not self.lquick:  # conditional quantiles are not evaluated for quick evaluation
+                cond_quantiple_ds = Postprocess.append_ds(batch_ds, cond_quantiple_ds, self.cond_quantile_vars,
+                                                          "init_time", dtype=np.float16)
             # ... and increment sample_ind
             sample_ind += self.batch_size
         # end of while-loop for samples
         # safe dataset with evaluation metrics for later use
         self.eval_metrics_ds = eval_metric_ds
         self.cond_quantiple_ds = cond_quantiple_ds
-        #self.add_ensemble_dim()
 
     # all methods of the run factory
     def init_session(self):
@@ -1207,19 +1214,21 @@ class Postprocess(TrainModel):
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--results_dir", type=str, default='results',
-                        help="ignored if output_gif_dir is specified")
-    parser.add_argument("--checkpoint",
-                        help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")
+                        help="Directory to save the results")
+    parser.add_argument("--checkpoint", help="Directory with checkpoint or checkpoint name (e.g. ${dir}/model-2000)")
     parser.add_argument("--mode", type=str, choices=['train', 'val', 'test'], default='test',
                         help='mode for dataset, val or test.')
     parser.add_argument("--batch_size", type=int, default=8, help="number of samples in batch")
     parser.add_argument("--num_stochastic_samples", type=int, default=1)
     parser.add_argument("--gpu_mem_frac", type=float, default=0.95, help="fraction of gpu memory to use")
     parser.add_argument("--seed", type=int, default=7)
-    parser.add_argument("--evaluation_metrics", "-eval_metrics", dest="eval_metrics", nargs="+", default=("mse", "psnr", "ssim","acc"),
+    parser.add_argument("--evaluation_metrics", "-eval_metrics", dest="eval_metrics", nargs="+",
+                        default=("mse", "psnr", "ssim", "acc"),
                         help="Metrics to be evaluate the trained model. Must be known metrics, see Scores-class.")
     parser.add_argument("--channel", "-channel", dest="channel", type=int, default=0,
                         help="Channel which is used for evaluation.")
+    parser.add_argument("--lquick_evaluation", "-lquick", dest="lquick", default=False, action="store_true",
+                        help="Flag to perform a (reduced) quick evaluation based on MSE only.")
     args = parser.parse_args()
 
     print('----------------------------------- Options ------------------------------------')
@@ -1227,16 +1236,25 @@ def main():
         print(k, "=", v)
     print('------------------------------------- End --------------------------------------')
 
+    eval_metrics = args.eval_metrics
+    results_dir = args.results_dir
+    if args.lquick:  # in case of quick evaluation, only evaluate MSE and modify results_dir
+        eval_metrics = ["mse"]
+        if not os.path.isfile(args.checkpoint):
+            raise ValueError("Pass a specific checkpoint-file for quick evaluation.")
+        results_dir = args.results_dir + "_{0}".format(os.path.basename(args.checkpoint))
+
     # initialize postprocessing instance
-    postproc_instance = Postprocess(results_dir=args.results_dir, checkpoint=args.checkpoint, mode="test",
+    postproc_instance = Postprocess(results_dir=results_dir, checkpoint=args.checkpoint, mode="test",
                                     batch_size=args.batch_size, num_stochastic_samples=args.num_stochastic_samples,
                                     gpu_mem_frac=args.gpu_mem_frac, seed=args.seed, args=args,
                                     eval_metrics=args.eval_metrics, channel=args.channel)
+                                    eval_metrics=eval_metrics, channel=args.channel, lquick=args.lquick)
     # run the postprocessing
     postproc_instance.run()
     postproc_instance.handle_eval_metrics()
-    postproc_instance.plot_example_forecasts(metric=args.eval_metrics[0], channel=args.channel)
-    postproc_instance.plot_conditional_quantiles()
+    if not args.lquick:  # don't produce additional plots in case of quick evaluation
+        postproc_instance.plot_example_forecasts(metric=args.eval_metrics[0], channel=args.channel)
+        postproc_instance.plot_conditional_quantiles()
 
 
 if __name__ == '__main__':
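
For reference, a minimal usage sketch (not part of the patch) of the quick-evaluation path introduced above. It mirrors what main() does when --lquick_evaluation is set: metrics are restricted to MSE, a concrete checkpoint file is required, and the additional plots are skipped. The module import and the checkpoint path are assumptions for illustration only.

# Hypothetical sketch: drive the reduced (MSE-only) evaluation directly from Python.
# The checkpoint path is a placeholder; Postprocess comes from the patched script.
import os
from main_visualize_postprocess import Postprocess

checkpoint = "/path/to/checkpoints/model-2000"  # placeholder: must be a specific checkpoint file
if not os.path.isfile(checkpoint):
    raise ValueError("Pass a specific checkpoint-file for quick evaluation.")

# as in main(), tag the results directory with the checkpoint name
results_dir = "results" + "_{0}".format(os.path.basename(checkpoint))

postproc = Postprocess(results_dir=results_dir, checkpoint=checkpoint, mode="test",
                       batch_size=8, num_stochastic_samples=1, gpu_mem_frac=0.95, seed=7,
                       eval_metrics=["mse"], channel=0, lquick=True)
postproc.run()
postproc.handle_eval_metrics()  # example-forecast and conditional-quantile plots are skipped, as in main() with -lquick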