Skip to content
Snippets Groups Projects
Commit b4240387 authored by Michael Langguth's avatar Michael Langguth
Browse files

Introduce lquick-flag to postprocessing, which allows a reduced but accelerated evaluation.

parent a8c32886
No related branches found
No related tags found
No related merge requests found
Pipeline #76828 passed
...@@ -33,7 +33,9 @@ class Postprocess(TrainModel): ...@@ -33,7 +33,9 @@ class Postprocess(TrainModel):
def __init__(self, results_dir: str = None, checkpoint: str = None, mode: str = "test", batch_size: int = None, def __init__(self, results_dir: str = None, checkpoint: str = None, mode: str = "test", batch_size: int = None,
num_stochastic_samples: int = 1, stochastic_plot_id: int = 0, gpu_mem_frac: float = None, num_stochastic_samples: int = 1, stochastic_plot_id: int = 0, gpu_mem_frac: float = None,
seed: int = None, channel: int = 0, args=None, run_mode: str = "deterministic", seed: int = None, channel: int = 0, args=None, run_mode: str = "deterministic",
eval_metrics: List = ("mse", "psnr", "ssim","acc"), clim_path: str ="/p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/T2monthly"): eval_metrics: List = ("mse", "psnr", "ssim", "acc"),
clim_path: str = "/p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/T2monthly",
lquick: bool = None):
""" """
Initialization of the class instance for postprocessing (generation of forecasts from trained model + Initialization of the class instance for postprocessing (generation of forecasts from trained model +
basic evauation). basic evauation).
...@@ -50,7 +52,8 @@ class Postprocess(TrainModel): ...@@ -50,7 +52,8 @@ class Postprocess(TrainModel):
:param args: namespace of parsed arguments :param args: namespace of parsed arguments
:param run_mode: "deterministic" or "stochastic", default: "deterministic", "stochastic is not supported yet!!! :param run_mode: "deterministic" or "stochastic", default: "deterministic", "stochastic is not supported yet!!!
:param eval_metrics: metrics used to evaluate the trained model :param eval_metrics: metrics used to evaluate the trained model
:param clim_path: the path to the climatology nc file :param clim_path: the path to the netCDF-file storing climatolgical data
:param lquick: flag for quick evaluation
""" """
# copy over attributes from parsed argument # copy over attributes from parsed argument
self.results_dir = self.output_dir = os.path.normpath(results_dir) self.results_dir = self.output_dir = os.path.normpath(results_dir)
...@@ -68,6 +71,7 @@ class Postprocess(TrainModel): ...@@ -68,6 +71,7 @@ class Postprocess(TrainModel):
self.run_mode = run_mode self.run_mode = run_mode
self.mode = mode self.mode = mode
self.channel = channel self.channel = channel
self.lquick = lquick
# Attributes set during runtime # Attributes set during runtime
self.norm_cls = None self.norm_cls = None
# configuration of basic evaluation # configuration of basic evaluation
...@@ -82,7 +86,7 @@ class Postprocess(TrainModel): ...@@ -82,7 +86,7 @@ class Postprocess(TrainModel):
self.model_hparams_dict_load = self.get_model_hparams_dict() self.model_hparams_dict_load = self.get_model_hparams_dict()
# set input paths and forecast product dictionary # set input paths and forecast product dictionary
self.input_dir, self.input_dir_pkl = self.get_input_dirs() self.input_dir, self.input_dir_pkl = self.get_input_dirs()
self.fcst_products = {"persistence": "pfcst", self.model: "mfcst"} self.fcst_products = {self.model: "mfcst"} if lquick else {"persistence": "pfcst", self.model: "mfcst"}
# correct number of stochastic samples if necessary # correct number of stochastic samples if necessary
self.check_num_stochastic_samples() self.check_num_stochastic_samples()
# get metadata # get metadata
...@@ -102,9 +106,9 @@ class Postprocess(TrainModel): ...@@ -102,9 +106,9 @@ class Postprocess(TrainModel):
self.setup_model(mode=self.mode) self.setup_model(mode=self.mode)
self.setup_graph() self.setup_graph()
self.setup_gpu_config() self.setup_gpu_config()
if "acc" in eval_metrics:
self.load_climdata() self.load_climdata()
# Methods that are called during initialization # Methods that are called during initialization
def get_input_dirs(self): def get_input_dirs(self):
""" """
...@@ -551,21 +555,24 @@ class Postprocess(TrainModel): ...@@ -551,21 +555,24 @@ class Postprocess(TrainModel):
if os.path.exists(nc_fname): if os.path.exists(nc_fname):
print("%{0}: The file '{1}' already exists and is therefore skipped".format(method, nc_fname)) print("%{0}: The file '{1}' already exists and is therefore skipped".format(method, nc_fname))
else: elif not self.lquick:
self.save_ds_to_netcdf(batch_ds.isel(init_time=i), nc_fname) self.save_ds_to_netcdf(batch_ds.isel(init_time=i), nc_fname)
else:
pass
# end of batch-loop # end of batch-loop
# write evaluation metric to corresponding dataset and sa # write evaluation metric to corresponding dataset and sa
eval_metric_ds = self.populate_eval_metric_ds(eval_metric_ds, batch_ds, sample_ind, eval_metric_ds = self.populate_eval_metric_ds(eval_metric_ds, batch_ds, sample_ind,
self.vars_in[self.channel]) self.vars_in[self.channel])
cond_quantiple_ds = Postprocess.append_ds(batch_ds, cond_quantiple_ds, self.cond_quantile_vars, "init_time",dtype=np.float16) if not self.lquick: # conditional quantiles are not evaluated for quick evaluation
cond_quantiple_ds = Postprocess.append_ds(batch_ds, cond_quantiple_ds, self.cond_quantile_vars,
"init_time", dtype=np.float16)
# ... and increment sample_ind # ... and increment sample_ind
sample_ind += self.batch_size sample_ind += self.batch_size
# end of while-loop for samples # end of while-loop for samples
# safe dataset with evaluation metrics for later use # safe dataset with evaluation metrics for later use
self.eval_metrics_ds = eval_metric_ds self.eval_metrics_ds = eval_metric_ds
self.cond_quantiple_ds = cond_quantiple_ds self.cond_quantiple_ds = cond_quantiple_ds
#self.add_ensemble_dim()
# all methods of the run factory # all methods of the run factory
def init_session(self): def init_session(self):
...@@ -1207,19 +1214,21 @@ class Postprocess(TrainModel): ...@@ -1207,19 +1214,21 @@ class Postprocess(TrainModel):
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--results_dir", type=str, default='results', parser.add_argument("--results_dir", type=str, default='results',
help="ignored if output_gif_dir is specified") help="Directory to save the results")
parser.add_argument("--checkpoint", parser.add_argument("--checkpoint", help="Directory with checkpoint or checkpoint name (e.g. ${dir}/model-2000)")
help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")
parser.add_argument("--mode", type=str, choices=['train', 'val', 'test'], default='test', parser.add_argument("--mode", type=str, choices=['train', 'val', 'test'], default='test',
help='mode for dataset, val or test.') help='mode for dataset, val or test.')
parser.add_argument("--batch_size", type=int, default=8, help="number of samples in batch") parser.add_argument("--batch_size", type=int, default=8, help="number of samples in batch")
parser.add_argument("--num_stochastic_samples", type=int, default=1) parser.add_argument("--num_stochastic_samples", type=int, default=1)
parser.add_argument("--gpu_mem_frac", type=float, default=0.95, help="fraction of gpu memory to use") parser.add_argument("--gpu_mem_frac", type=float, default=0.95, help="fraction of gpu memory to use")
parser.add_argument("--seed", type=int, default=7) parser.add_argument("--seed", type=int, default=7)
parser.add_argument("--evaluation_metrics", "-eval_metrics", dest="eval_metrics", nargs="+", default=("mse", "psnr", "ssim","acc"), parser.add_argument("--evaluation_metrics", "-eval_metrics", dest="eval_metrics", nargs="+",
default=("mse", "psnr", "ssim", "acc"),
help="Metrics to be evaluate the trained model. Must be known metrics, see Scores-class.") help="Metrics to be evaluate the trained model. Must be known metrics, see Scores-class.")
parser.add_argument("--channel", "-channel", dest="channel", type=int, default=0, parser.add_argument("--channel", "-channel", dest="channel", type=int, default=0,
help="Channel which is used for evaluation.") help="Channel which is used for evaluation.")
parser.add_argument("--lquick_evaluation", "-lquick", dest="lquick", default=False, action="store_true",
help="Flag if (reduced) quick evaluation based on MSE is performed.")
args = parser.parse_args() args = parser.parse_args()
print('----------------------------------- Options ------------------------------------') print('----------------------------------- Options ------------------------------------')
...@@ -1227,14 +1236,23 @@ def main(): ...@@ -1227,14 +1236,23 @@ def main():
print(k, "=", v) print(k, "=", v)
print('------------------------------------- End --------------------------------------') print('------------------------------------- End --------------------------------------')
eval_metrics = args.eval_metrics
results_dir = args.results_dir
if args.lquick: # in case of quick evaluation, onyl evaluate MSE and modify results_dir
eval_metrics = ["mse"]
if not os.path.isfile(args.checkpoint):
raise ValueError("Pass a specific checkpoint-file for quick evaluation.")
results_dir = args.results_dir + "_{0}".format(os.path.basename(args.checkpoint))
# initialize postprocessing instance # initialize postprocessing instance
postproc_instance = Postprocess(results_dir=args.results_dir, checkpoint=args.checkpoint, mode="test", postproc_instance = Postprocess(results_dir=results_dir, checkpoint=args.checkpoint, mode="test",
batch_size=args.batch_size, num_stochastic_samples=args.num_stochastic_samples, batch_size=args.batch_size, num_stochastic_samples=args.num_stochastic_samples,
gpu_mem_frac=args.gpu_mem_frac, seed=args.seed, args=args, gpu_mem_frac=args.gpu_mem_frac, seed=args.seed, args=args,
eval_metrics=args.eval_metrics, channel=args.channel) eval_metrics=eval_metrics, channel=args.channel, lquick=args.lquick)
# run the postprocessing # run the postprocessing
postproc_instance.run() postproc_instance.run()
postproc_instance.handle_eval_metrics() postproc_instance.handle_eval_metrics()
if not args.lquick: # don't produce additional plots in case of quick evaluation
postproc_instance.plot_example_forecasts(metric=args.eval_metrics[0], channel=args.channel) postproc_instance.plot_example_forecasts(metric=args.eval_metrics[0], channel=args.channel)
postproc_instance.plot_conditional_quantiles() postproc_instance.plot_conditional_quantiles()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment