diff --git a/video_prediction_tools/main_scripts/main_train_models.py b/video_prediction_tools/main_scripts/main_train_models.py
index 9e58de96a31913eb19678e151fac5c46d6e80409..b16e33919d9f335e0d2b45ad3309ad901e568f57 100644
--- a/video_prediction_tools/main_scripts/main_train_models.py
+++ b/video_prediction_tools/main_scripts/main_train_models.py
@@ -562,7 +562,8 @@ class BestModelSelector(object):
     Class to select the best performing model from multiple checkpoints created during training
     """
 
-    def __init__(self, model_dir: str, eval_metric: str, criterion: str = "min", channel: int = 0, seed: int = 42):
+    def __init__(self, model_dir: str, eval_metric: str, ltest: bool, criterion: str = "min", channel: int = 0,
+                 seed: int = 42):
         """
         Class to retrieve the best model checkpoint. The last one is also retained.
         :param model_dir: path to directory where checkpoints are saved (the trained model output directory)
@@ -570,6 +571,7 @@ class BestModelSelector(object):
         :param criterion: set to 'min' ('max') for negatively (positively) oriented metrics
         :param channel: channel of data used for selection
         :param seed: seed for the Postprocess-instance
+        :param ltest: flag to allow bootstrapping in Postprocessing on tiny datasets
         """
         method = self.__class__.__name__
         # sanity check
@@ -581,6 +583,7 @@ class BestModelSelector(object):
         self.channel = channel
         self.metric = eval_metric
         self.checkpoint_base_dir = model_dir
+        self.ltest = ltest
         self.checkpoints_all = BestModelSelector.get_checkpoints_dirs(model_dir)
         self.ncheckpoints = len(self.checkpoints_all)
         # evaluate all checkpoints...
@@ -604,7 +607,7 @@ class BestModelSelector(object):
             results_dir_eager = os.path.join(checkpoint, "results_eager")
             eager_eval = Postprocess(results_dir=results_dir_eager, checkpoint=checkpoint, data_mode="val", batch_size=32,
                                      seed=self.seed, eval_metrics=[eval_metric], channel=self.channel, frac_data=0.33,
-                                     lquick=True)
+                                     lquick=True, ltest=self.ltest)
             eager_eval.run()
             eager_eval.handle_eval_metrics()
 
@@ -728,6 +731,8 @@ def main():
     parser.add_argument("--frac_intv_save", type=float, default=0.01,
                         help="Fraction of all iteration steps to define the saving interval.")
     parser.add_argument("--seed", default=1234, type=int)
+    parser.add_argument("--test_mode", "-test", dest="test_mode", default=False, action="store_true",
+                        help="Test mode for postprocessing to allow bootstrapping on small datasets.")
     args = parser.parse_args()
 
     # start timing for the whole run
@@ -753,7 +758,7 @@ def main():
 
     # select best model
     if args.dataset == "era5" and args.frac_start_save < 1.:
-        _ = BestModelSelector(args.output_dir, "mse")
+        _ = BestModelSelector(args.output_dir, "mse", args.test_mode)
         timeit_finish = time.time()
         print("Selecting the best model checkpoint took {0:.2f} minutes.".format((timeit_finish - timeit_after_train)/60.))
     else:
diff --git a/video_prediction_tools/main_scripts/main_visualize_postprocess.py b/video_prediction_tools/main_scripts/main_visualize_postprocess.py
index dffbc898d14b3df19724c8fdcb0ff4a01e515f41..0c9f8e434b9c1706b55a7ba6a8a99b3a7156e628 100644
--- a/video_prediction_tools/main_scripts/main_visualize_postprocess.py
+++ b/video_prediction_tools/main_scripts/main_visualize_postprocess.py
@@ -37,8 +37,9 @@ class Postprocess(TrainModel):
     def __init__(self, results_dir: str = None, checkpoint: str = None, data_mode: str = "test", batch_size: int = None,
                  gpu_mem_frac: float = None, num_stochastic_samples: int = 1, stochastic_plot_id: int = 0,
                  seed: int = None, channel: int = 0, run_mode: str = "deterministic", lquick: bool = None,
-                 frac_data: float = 1., eval_metrics: List = ("mse", "psnr", "ssim", "acc"), args=None,
-                 clim_path: str = "/p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/T2monthly/climatology_t2m_1991-2020.nc"):
+                 frac_data: float = 1., eval_metrics: List = ("mse", "psnr", "ssim", "acc"), ltest=False,
+                 clim_path: str = "/p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/T2monthly/"+
+                                  "climatology_t2m_1991-2020.nc", args=None):
         """
         Initialization of the class instance for postprocessing (generation of forecasts from trained model +
         basic evauation).
@@ -56,6 +57,7 @@ class Postprocess(TrainModel):
         :param lquick: flag for quick evaluation
         :param frac_data: fraction of dataset to be used for evaluation (only applied when shuffling is active)
         :param eval_metrics: metrics used to evaluate the trained model
+        :param ltest: flag for test mode to allow bootstrapping on tiny datasets
         :param clim_path: the path to the netCDF-file storing climatolgical data
         :param args: namespace of parsed arguments
         """
@@ -86,6 +88,7 @@ class Postprocess(TrainModel):
         self.eval_metrics = eval_metrics
         self.nboots_block = 1000
         self.block_length = 7 * 24  # this corresponds to a block length of 7 days in case of hourly forecasts
+        if ltest: self.block_length = 1
         # initialize evrything to get an executable Postprocess instance
         if args is not None:
             self.save_args_to_option_json()  # create options.json in results directory
@@ -1265,8 +1268,10 @@ def main():
                         help="(Only) metric to evaluate when quick evaluation (-lquick) is chosen.")
     parser.add_argument("--climatology_file", "-clim_fl", dest="clim_fl", type=str, default=False,
                         help="The path to the climatology_t2m_1991-2020.nc file ")
-    parse.add_argument("--frac_data", "-f_dt", dest="f_dt",type=float,default=1,
-                       help="fraction of dataset to be used for evaluation (only applied when shuffling is active)")
+    parser.add_argument("--frac_data", "-f_dt", dest="f_dt", type=float, default=1.,
+                        help="Fraction of dataset to be used for evaluation (only applied when shuffling is active).")
+    parser.add_argument("--test_mode", "-test", dest="test_mode", default=False, action="store_true",
+                        help="Test mode for postprocessing to allow bootstrapping on small datasets.")
     args = parser.parse_args()
 
     method = os.path.basename(__file__)
@@ -1293,7 +1298,7 @@ def main():
                                       batch_size=args.batch_size, num_stochastic_samples=args.num_stochastic_samples,
                                       gpu_mem_frac=args.gpu_mem_frac, seed=args.seed, args=args,
                                       eval_metrics=eval_metrics, channel=args.channel, lquick=args.lquick,
-                                      clim_path=args.clim_fl,frac_data=args.frac_data)
+                                      clim_path=args.clim_fl,frac_data=args.frac_data, ltest=args.test_mode)
     # run the postprocessing
     postproc_instance.run()
     postproc_instance.handle_eval_metrics()
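
Note on the ltest/--test_mode flag: Postprocess performs a block bootstrap with nboots_block = 1000 and a default block_length = 7 * 24 (one week of hourly forecasts), which cannot be drawn from the handful of time steps in a test dataset; setting block_length = 1 lets the resampling degenerate to an ordinary bootstrap so the evaluation pipeline can still be exercised. Below is a minimal, self-contained sketch of that reasoning, using NumPy and the hypothetical names block_bootstrap and sample_errors; it is not the repository's actual bootstrapping code.

# Illustrative only: a standalone block bootstrap, not the implementation used in Postprocess.
import numpy as np


def block_bootstrap(errors: np.ndarray, block_length: int, nboots: int = 1000, seed: int = 42) -> np.ndarray:
    """Return nboots bootstrap estimates of the mean error, resampling whole blocks with replacement."""
    nblocks = errors.size // block_length
    if nblocks < 1:
        raise ValueError(f"{errors.size} samples cannot form a single block of length {block_length}.")
    rng = np.random.default_rng(seed)
    blocks = errors[:nblocks * block_length].reshape(nblocks, block_length)
    idx = rng.integers(0, nblocks, size=(nboots, nblocks))  # draw block indices with replacement
    return blocks[idx].mean(axis=(1, 2))                    # one mean per bootstrap sample


sample_errors = np.random.rand(48)          # e.g. only 48 hourly forecast errors in a tiny test run
# block_bootstrap(sample_errors, 7 * 24)    # fails: the dataset is shorter than a single block
boots = block_bootstrap(sample_errors, 1)   # what ltest / --test_mode effectively enables
print(boots.shape)                          # (1000,)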