Skip to content
Snippets Groups Projects
Commit b80c3305 authored by Michael Langguth's avatar Michael Langguth
Browse files

Adapt postprocessing and training to allow for bootstrapping on tiny datasets.

parent c8adb6bc
No related branches found
No related tags found
No related merge requests found
...@@ -562,7 +562,8 @@ class BestModelSelector(object): ...@@ -562,7 +562,8 @@ class BestModelSelector(object):
Class to select the best performing model from multiple checkpoints created during training Class to select the best performing model from multiple checkpoints created during training
""" """
def __init__(self, model_dir: str, eval_metric: str, criterion: str = "min", channel: int = 0, seed: int = 42): def __init__(self, model_dir: str, eval_metric: str, ltest: bool, criterion: str = "min", channel: int = 0,
seed: int = 42):
""" """
Class to retrieve the best model checkpoint. The last one is also retained. Class to retrieve the best model checkpoint. The last one is also retained.
:param model_dir: path to directory where checkpoints are saved (the trained model output directory) :param model_dir: path to directory where checkpoints are saved (the trained model output directory)
...@@ -570,6 +571,7 @@ class BestModelSelector(object): ...@@ -570,6 +571,7 @@ class BestModelSelector(object):
:param criterion: set to 'min' ('max') for negatively (positively) oriented metrics :param criterion: set to 'min' ('max') for negatively (positively) oriented metrics
:param channel: channel of data used for selection :param channel: channel of data used for selection
:param seed: seed for the Postprocess-instance :param seed: seed for the Postprocess-instance
:param ltest: flag to allow bootstrapping in Postprocessing on tiny datasets
""" """
method = self.__class__.__name__ method = self.__class__.__name__
# sanity check # sanity check
...@@ -581,6 +583,7 @@ class BestModelSelector(object): ...@@ -581,6 +583,7 @@ class BestModelSelector(object):
self.channel = channel self.channel = channel
self.metric = eval_metric self.metric = eval_metric
self.checkpoint_base_dir = model_dir self.checkpoint_base_dir = model_dir
self.ltest = ltest
self.checkpoints_all = BestModelSelector.get_checkpoints_dirs(model_dir) self.checkpoints_all = BestModelSelector.get_checkpoints_dirs(model_dir)
self.ncheckpoints = len(self.checkpoints_all) self.ncheckpoints = len(self.checkpoints_all)
# evaluate all checkpoints... # evaluate all checkpoints...
...@@ -604,7 +607,7 @@ class BestModelSelector(object): ...@@ -604,7 +607,7 @@ class BestModelSelector(object):
results_dir_eager = os.path.join(checkpoint, "results_eager") results_dir_eager = os.path.join(checkpoint, "results_eager")
eager_eval = Postprocess(results_dir=results_dir_eager, checkpoint=checkpoint, data_mode="val", batch_size=32, eager_eval = Postprocess(results_dir=results_dir_eager, checkpoint=checkpoint, data_mode="val", batch_size=32,
seed=self.seed, eval_metrics=[eval_metric], channel=self.channel, frac_data=0.33, seed=self.seed, eval_metrics=[eval_metric], channel=self.channel, frac_data=0.33,
lquick=True) lquick=True, ltest=self.ltest)
eager_eval.run() eager_eval.run()
eager_eval.handle_eval_metrics() eager_eval.handle_eval_metrics()
...@@ -728,6 +731,8 @@ def main(): ...@@ -728,6 +731,8 @@ def main():
parser.add_argument("--frac_intv_save", type=float, default=0.01, parser.add_argument("--frac_intv_save", type=float, default=0.01,
help="Fraction of all iteration steps to define the saving interval.") help="Fraction of all iteration steps to define the saving interval.")
parser.add_argument("--seed", default=1234, type=int) parser.add_argument("--seed", default=1234, type=int)
parser.add_argument("--test_mode", "-test", dest="test_mode", default=False, action="store_true",
help="Test mode for postprocessing to allow bootstrapping on small datasets.")
args = parser.parse_args() args = parser.parse_args()
# start timing for the whole run # start timing for the whole run
...@@ -753,7 +758,7 @@ def main(): ...@@ -753,7 +758,7 @@ def main():
# select best model # select best model
if args.dataset == "era5" and args.frac_start_save < 1.: if args.dataset == "era5" and args.frac_start_save < 1.:
_ = BestModelSelector(args.output_dir, "mse") _ = BestModelSelector(args.output_dir, "mse", args.test_mode)
timeit_finish = time.time() timeit_finish = time.time()
print("Selecting the best model checkpoint took {0:.2f} minutes.".format((timeit_finish - timeit_after_train)/60.)) print("Selecting the best model checkpoint took {0:.2f} minutes.".format((timeit_finish - timeit_after_train)/60.))
else: else:
......
...@@ -37,8 +37,9 @@ class Postprocess(TrainModel): ...@@ -37,8 +37,9 @@ class Postprocess(TrainModel):
def __init__(self, results_dir: str = None, checkpoint: str = None, data_mode: str = "test", batch_size: int = None, def __init__(self, results_dir: str = None, checkpoint: str = None, data_mode: str = "test", batch_size: int = None,
gpu_mem_frac: float = None, num_stochastic_samples: int = 1, stochastic_plot_id: int = 0, gpu_mem_frac: float = None, num_stochastic_samples: int = 1, stochastic_plot_id: int = 0,
seed: int = None, channel: int = 0, run_mode: str = "deterministic", lquick: bool = None, seed: int = None, channel: int = 0, run_mode: str = "deterministic", lquick: bool = None,
frac_data: float = 1., eval_metrics: List = ("mse", "psnr", "ssim", "acc"), args=None, frac_data: float = 1., eval_metrics: List = ("mse", "psnr", "ssim", "acc"), ltest=False,
clim_path: str = "/p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/T2monthly/climatology_t2m_1991-2020.nc"): clim_path: str = "/p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/T2monthly/"+
"climatology_t2m_1991-2020.nc", args=None):
""" """
Initialization of the class instance for postprocessing (generation of forecasts from trained model + Initialization of the class instance for postprocessing (generation of forecasts from trained model +
basic evaluation). basic evaluation).
...@@ -56,6 +57,7 @@ class Postprocess(TrainModel): ...@@ -56,6 +57,7 @@ class Postprocess(TrainModel):
:param lquick: flag for quick evaluation :param lquick: flag for quick evaluation
:param frac_data: fraction of dataset to be used for evaluation (only applied when shuffling is active) :param frac_data: fraction of dataset to be used for evaluation (only applied when shuffling is active)
:param eval_metrics: metrics used to evaluate the trained model :param eval_metrics: metrics used to evaluate the trained model
:param ltest: flag for test mode to allow bootstrapping on tiny datasets
:param clim_path: the path to the netCDF-file storing climatological data :param clim_path: the path to the netCDF-file storing climatological data
:param args: namespace of parsed arguments :param args: namespace of parsed arguments
""" """
...@@ -86,6 +88,7 @@ class Postprocess(TrainModel): ...@@ -86,6 +88,7 @@ class Postprocess(TrainModel):
self.eval_metrics = eval_metrics self.eval_metrics = eval_metrics
self.nboots_block = 1000 self.nboots_block = 1000
self.block_length = 7 * 24 # this corresponds to a block length of 7 days in case of hourly forecasts self.block_length = 7 * 24 # this corresponds to a block length of 7 days in case of hourly forecasts
if ltest: self.block_length = 1
# initialize everything to get an executable Postprocess instance # initialize everything to get an executable Postprocess instance
if args is not None: if args is not None:
self.save_args_to_option_json() # create options.json in results directory self.save_args_to_option_json() # create options.json in results directory
...@@ -1265,8 +1268,10 @@ def main(): ...@@ -1265,8 +1268,10 @@ def main():
help="(Only) metric to evaluate when quick evaluation (-lquick) is chosen.") help="(Only) metric to evaluate when quick evaluation (-lquick) is chosen.")
parser.add_argument("--climatology_file", "-clim_fl", dest="clim_fl", type=str, default=False, parser.add_argument("--climatology_file", "-clim_fl", dest="clim_fl", type=str, default=False,
help="The path to the climatology_t2m_1991-2020.nc file ") help="The path to the climatology_t2m_1991-2020.nc file ")
parse.add_argument("--frac_data", "-f_dt", dest="f_dt",type=float,default=1, parser.add_argument("--frac_data", "-f_dt", dest="f_dt", type=float, default=1.,
help="fraction of dataset to be used for evaluation (only applied when shuffling is active)") help="Fraction of dataset to be used for evaluation (only applied when shuffling is active).")
parser.add_argument("--test_mode", "-test", dest="test_mode", default=False, action="store_true",
help="Test mode for postprocessing to allow bootstrapping on small datasets.")
args = parser.parse_args() args = parser.parse_args()
method = os.path.basename(__file__) method = os.path.basename(__file__)
...@@ -1293,7 +1298,7 @@ def main(): ...@@ -1293,7 +1298,7 @@ def main():
batch_size=args.batch_size, num_stochastic_samples=args.num_stochastic_samples, batch_size=args.batch_size, num_stochastic_samples=args.num_stochastic_samples,
gpu_mem_frac=args.gpu_mem_frac, seed=args.seed, args=args, gpu_mem_frac=args.gpu_mem_frac, seed=args.seed, args=args,
eval_metrics=eval_metrics, channel=args.channel, lquick=args.lquick, eval_metrics=eval_metrics, channel=args.channel, lquick=args.lquick,
clim_path=args.clim_fl,frac_data=args.frac_data) clim_path=args.clim_fl,frac_data=args.frac_data, ltest=args.test_mode)
# run the postprocessing # run the postprocessing
postproc_instance.run() postproc_instance.run()
postproc_instance.handle_eval_metrics() postproc_instance.handle_eval_metrics()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment