diff --git a/video_prediction_tools/HPC_scripts/visualize_postprocess_era5_template.sh b/video_prediction_tools/HPC_scripts/visualize_postprocess_era5_template.sh index ed35fd8b68d2d8593c4e9ff411fd0c142b360204..a29b8e1b0f297dc986c5633a45857c590d37b514 100644 --- a/video_prediction_tools/HPC_scripts/visualize_postprocess_era5_template.sh +++ b/video_prediction_tools/HPC_scripts/visualize_postprocess_era5_template.sh @@ -39,6 +39,7 @@ fi source_dir=/my/source/dir/ checkpoint_dir=/my/trained/model/dir results_dir=/my/results/dir +lquick="" # name of model model=convLSTM @@ -46,5 +47,5 @@ model=convLSTM # run postprocessing/generation of model results including evaluation metrics srun python -u ../main_scripts/main_visualize_postprocess.py --checkpoint ${checkpoint_dir} --mode test \ --results_dir ${results_dir} --batch_size 4 \ - --num_stochastic_samples 1 \ + --num_stochastic_samples 1 ${lquick} \ > postprocess_era5-out_all.${SLURM_JOB_ID} diff --git a/video_prediction_tools/main_scripts/main_visualize_postprocess.py b/video_prediction_tools/main_scripts/main_visualize_postprocess.py index 4cd800f880b927b84ac8d7a0a929c78bac777a94..5d2528682300420d8fee022782e470786662c556 100644 --- a/video_prediction_tools/main_scripts/main_visualize_postprocess.py +++ b/video_prediction_tools/main_scripts/main_visualize_postprocess.py @@ -30,10 +30,12 @@ from postprocess_plotting import plot_avg_eval_metrics, plot_cond_quantile, crea class Postprocess(TrainModel): - def __init__(self, results_dir: str = None, checkpoint: str= None, mode: str = "test", batch_size: int = None, + def __init__(self, results_dir: str = None, checkpoint: str = None, mode: str = "test", batch_size: int = None, num_stochastic_samples: int = 1, stochastic_plot_id: int = 0, gpu_mem_frac: float = None, seed: int = None, channel: int = 0, args=None, run_mode: str = "deterministic", - eval_metrics: List = ("mse", "psnr", "ssim","acc"), clim_path: str ="/p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/T2monthly"): + eval_metrics: List = ("mse", "psnr", "ssim", "acc"), + clim_path: str = "/p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/T2monthly", + lquick: bool = None): """ Initialization of the class instance for postprocessing (generation of forecasts from trained model + basic evauation). @@ -50,7 +52,8 @@ class Postprocess(TrainModel): :param args: namespace of parsed arguments :param run_mode: "deterministic" or "stochastic", default: "deterministic", "stochastic is not supported yet!!! :param eval_metrics: metrics used to evaluate the trained model - :param clim_path: the path to the climatology nc file + :param clim_path: the path to the netCDF-file storing climatolgical data + :param lquick: flag for quick evaluation """ # copy over attributes from parsed argument self.results_dir = self.output_dir = os.path.normpath(results_dir) @@ -63,11 +66,14 @@ class Postprocess(TrainModel): self.stochastic_plot_id = stochastic_plot_id self.args = args self.checkpoint = checkpoint + if not os.path.isfile(self.checkpoint+".meta"): + _ = check_dir(self.checkpoint) + self.checkpoint += "/" # trick to handle checkpoint-directory and file simulataneously self.clim_path = clim_path - _ = check_dir(self.checkpoint) self.run_mode = run_mode self.mode = mode self.channel = channel + self.lquick = lquick # Attributes set during runtime self.norm_cls = None # configuration of basic evaluation @@ -82,7 +88,7 @@ class Postprocess(TrainModel): self.model_hparams_dict_load = self.get_model_hparams_dict() # set input paths and forecast product dictionary self.input_dir, self.input_dir_pkl = self.get_input_dirs() - self.fcst_products = {"persistence": "pfcst", self.model: "mfcst"} + self.fcst_products = {self.model: "mfcst"} if lquick else {"persistence": "pfcst", self.model: "mfcst"} # correct number of stochastic samples if necessary self.check_num_stochastic_samples() # get metadata @@ -102,8 +108,10 @@ class Postprocess(TrainModel): self.setup_model(mode=self.mode) self.setup_graph() self.setup_gpu_config() - self.load_climdata() - + if "acc" in eval_metrics: + self.load_climdata() + else: + self.data_clim = None # Methods that are called during initialization def get_input_dirs(self): @@ -141,10 +149,11 @@ class Postprocess(TrainModel): method_name = Postprocess.copy_data_model_json.__name__ # correctness of self.checkpoint and self.results_dir is already checked in __init__ - model_opt_js = os.path.join(self.checkpoint, "options.json") - model_ds_js = os.path.join(self.checkpoint, "dataset_hparams.json") - model_hp_js = os.path.join(self.checkpoint, "model_hparams.json") - model_dd_js = os.path.join(self.checkpoint, "data_split.json") + checkpoint_dir = os.path.dirname(self.checkpoint) + model_opt_js = os.path.join(checkpoint_dir, "options.json") + model_ds_js = os.path.join(checkpoint_dir, "dataset_hparams.json") + model_hp_js = os.path.join(checkpoint_dir, "model_hparams.json") + model_dd_js = os.path.join(checkpoint_dir, "data_split.json") if os.path.isfile(model_opt_js): shutil.copy(model_opt_js, os.path.join(self.results_dir, "options_checkpoints.json")) @@ -521,6 +530,8 @@ class Postprocess(TrainModel): # get normalized and denormalized input data input_results, input_images_denorm, t_starts = self.get_input_data_per_batch(self.inputs) # feed and run the trained model; returned array has the shape [batchsize, seq_len, lat, lon, channel] + print("%{0}: Start generating {1:d} predictions at current sample index {2:d}".format(method, self.batch_size, + sample_ind)) feed_dict = {input_ph: input_results[name] for name, input_ph in self.inputs.items()} gen_images = self.sess.run(self.video_model.outputs['gen_images'], feed_dict=feed_dict) @@ -536,7 +547,9 @@ class Postprocess(TrainModel): nbs = np.minimum(self.batch_size, self.num_samples_per_epoch - sample_ind) batch_ds = batch_ds.isel(init_time=slice(0, nbs)) - for i in np.arange(nbs): + # run over mini-batch only if quick evaluation is NOT active + for i in np.arange(0 if self.lquick else nbs): + print("%{0}: Process mini-batch sample {1:d}/{2:d}".format(method, i+1, nbs)) # work-around to make use of get_persistence_forecast_per_sample-method times_seq = (pd.date_range(times_0[i], periods=int(self.sequence_length), freq="h")).to_pydatetime() # get persistence forecast for sequences at hand and write to dataset @@ -558,14 +571,15 @@ class Postprocess(TrainModel): # write evaluation metric to corresponding dataset and sa eval_metric_ds = self.populate_eval_metric_ds(eval_metric_ds, batch_ds, sample_ind, self.vars_in[self.channel]) - cond_quantiple_ds = Postprocess.append_ds(batch_ds, cond_quantiple_ds, self.cond_quantile_vars, "init_time",dtype=np.float16) + if not self.lquick: # conditional quantiles are not evaluated for quick evaluation + cond_quantiple_ds = Postprocess.append_ds(batch_ds, cond_quantiple_ds, self.cond_quantile_vars, + "init_time", dtype=np.float16) # ... and increment sample_ind sample_ind += self.batch_size # end of while-loop for samples # safe dataset with evaluation metrics for later use self.eval_metrics_ds = eval_metric_ds self.cond_quantiple_ds = cond_quantiple_ds - #self.add_ensemble_dim() # all methods of the run factory def init_session(self): @@ -1045,7 +1059,6 @@ class Postprocess(TrainModel): # Retrieve starting index ind_first_m = list(time_pickle_first).index(np.array(t_persistence_first_m[0])) - # print("time_pickle_second:", time_pickle_second) ind_second_m = list(time_pickle_second).index(np.array(t_persistence_second_m[0])) # append the sequence of the second month to the first month @@ -1207,36 +1220,54 @@ class Postprocess(TrainModel): def main(): parser = argparse.ArgumentParser() parser.add_argument("--results_dir", type=str, default='results', - help="ignored if output_gif_dir is specified") - parser.add_argument("--checkpoint", - help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)") + help="Directory to save the results") + parser.add_argument("--checkpoint", help="Directory with checkpoint or checkpoint name (e.g. ${dir}/model-2000)") parser.add_argument("--mode", type=str, choices=['train', 'val', 'test'], default='test', help='mode for dataset, val or test.') parser.add_argument("--batch_size", type=int, default=8, help="number of samples in batch") parser.add_argument("--num_stochastic_samples", type=int, default=1) parser.add_argument("--gpu_mem_frac", type=float, default=0.95, help="fraction of gpu memory to use") parser.add_argument("--seed", type=int, default=7) - parser.add_argument("--evaluation_metrics", "-eval_metrics", dest="eval_metrics", nargs="+", default=("mse", "psnr", "ssim","acc"), + parser.add_argument("--evaluation_metrics", "-eval_metrics", dest="eval_metrics", nargs="+", + default=("mse", "psnr", "ssim", "acc"), help="Metrics to be evaluate the trained model. Must be known metrics, see Scores-class.") parser.add_argument("--channel", "-channel", dest="channel", type=int, default=0, help="Channel which is used for evaluation.") + parser.add_argument("--lquick_evaluation", "-lquick", dest="lquick", default=False, action="store_true", + help="Flag if (reduced) quick evaluation based on MSE is performed.") + parser.add_argument("--evaluation_metric_quick", "-metric_quick", dest="metric_quick", type=str, default="mse", + help="(Only) metric to evaluate when quick evaluation (-lquick) is chosen.") args = parser.parse_args() + method = os.path.basename(__file__) + print('----------------------------------- Options ------------------------------------') for k, v in args._get_kwargs(): print(k, "=", v) print('------------------------------------- End --------------------------------------') + eval_metrics = args.eval_metrics + results_dir = args.results_dir + if args.lquick: # in case of quick evaluation, onyl evaluate MSE and modify results_dir + eval_metrics = [args.metric_quick] + if not os.path.isfile(args.checkpoint+".meta"): + raise ValueError("%{0}: Pass a specific checkpoint-file for quick evaluation.".format(method)) + chp = os.path.basename(args.checkpoint) + results_dir = args.results_dir + "_{0}".format(chp) + print("%{0}: Quick evaluation is chosen. \n * evaluation metric: {1}\n".format(method, args.metric_quick) + + "* checkpointed model: {0}\n * no conditional quantile and forecast example plots".format(chp)) + # initialize postprocessing instance - postproc_instance = Postprocess(results_dir=args.results_dir, checkpoint=args.checkpoint, mode="test", + postproc_instance = Postprocess(results_dir=results_dir, checkpoint=args.checkpoint, mode="test", batch_size=args.batch_size, num_stochastic_samples=args.num_stochastic_samples, gpu_mem_frac=args.gpu_mem_frac, seed=args.seed, args=args, - eval_metrics=args.eval_metrics, channel=args.channel) + eval_metrics=eval_metrics, channel=args.channel, lquick=args.lquick) # run the postprocessing postproc_instance.run() postproc_instance.handle_eval_metrics() - postproc_instance.plot_example_forecasts(metric=args.eval_metrics[0], channel=args.channel) - postproc_instance.plot_conditional_quantiles() + if not args.lquick: # don't produce additional plots in case of quick evaluation + postproc_instance.plot_example_forecasts(metric=args.eval_metrics[0], channel=args.channel) + postproc_instance.plot_conditional_quantiles() if __name__ == '__main__': diff --git a/video_prediction_tools/model_modules/video_prediction/models/base_model.py b/video_prediction_tools/model_modules/video_prediction/models/base_model.py index 41f4bb3d1eb59c7e51624c03bfb8a582cb38b1e2..81b608495aac53615cf24f62ca4375702091bd75 100644 --- a/video_prediction_tools/model_modules/video_prediction/models/base_model.py +++ b/video_prediction_tools/model_modules/video_prediction/models/base_model.py @@ -230,6 +230,8 @@ class BaseVideoPredictionModel(object): return eval_outputs, eval_metrics def restore(self, sess, checkpoints, restore_to_checkpoint_mapping=None): + + method = BaseVideoPredictionModel.restore.__name__ if checkpoints: var_list = self.saveable_variables # possibly restore from multiple checkpoints. useful if subset of weights @@ -240,7 +242,7 @@ class BaseVideoPredictionModel(object): skip_global_step = len(checkpoints) > 1 savers = [] for checkpoint in checkpoints: - print("creating restore saver from checkpoint %s" % checkpoint) + print("%{0}: Creating restore saver from checkpoint '{1}'".format(method, checkpoint)) saver, _ = tf_utils.get_checkpoint_restore_saver( checkpoint, var_list, skip_global_step=skip_global_step, restore_to_checkpoint_mapping=restore_to_checkpoint_mapping) diff --git a/video_prediction_tools/model_modules/video_prediction/models/savp_model.py b/video_prediction_tools/model_modules/video_prediction/models/savp_model.py index e662de77f0cdd456e5d8d5a3927f1134abef128a..5b817106a9f4e8e4a91c7c6b7027765636411e94 100644 --- a/video_prediction_tools/model_modules/video_prediction/models/savp_model.py +++ b/video_prediction_tools/model_modules/video_prediction/models/savp_model.py @@ -686,10 +686,8 @@ class SAVPCell(tf.nn.rnn_cell.RNNCell): def generator_given_z_fn(inputs, mode, hparams): # all the inputs needs to have the same length for unrolling the rnn - print("inputs.items",inputs.items()) #20200822 bing inputs ={"images":inputs["images"]} - print("inputs 20200822:",inputs) #20200822 inputs = {name: tf_utils.maybe_pad_or_slice(input, hparams.sequence_length - 1) for name, input in inputs.items()} diff --git a/video_prediction_tools/postprocess/statistical_evaluation.py b/video_prediction_tools/postprocess/statistical_evaluation.py index 7ed9091d2c0a1684c579d891d057d46fee94eb1d..96cf4a47070bd6dc36f9d5cf51c577271385a7e1 100644 --- a/video_prediction_tools/postprocess/statistical_evaluation.py +++ b/video_prediction_tools/postprocess/statistical_evaluation.py @@ -202,7 +202,7 @@ class Scores: Class to calculate scores and skill scores. """ - known_scores = ["mse", "psnr","ssim", "acc"] + known_scores = ["mse", "psnr", "ssim", "acc"] def __init__(self, score_name: str, dims: List[str]): """ @@ -220,26 +220,6 @@ class Scores: # attributes set when run_calculation is called self.avg_dims = dims - # ML 2021-06-10: The following method is not runnable and yet, it is unclear if it is needed at all. - # Thus, it is commented out for potential later use (in case that it won't be discarded). - # def run_calculation(self, model_data, ref_data, dims2avg=None, **kwargs): - # - # method = Scores.run_calculation.__name__ - # - # model_data, ref_data = Scores.set_model_and_ref_data(model_data, ref_data, dims2avg=dims2avg) - # - # try: - # # if self.avg_dims is None: - # result = self.score_func(model_data, ref_data, **kwargs) - # # else: - # # result = self.score_func(model_data, ref_data, **kwargs) - # except Exception as err: - # print("%{0}: Calculation of '{1}' was not successful. Inspect error message!".format(method, - # self.score_name)) - # raise err - # - # return result - def set_score_name(self, score_name): method = Scores.set_score_name.__name__ diff --git a/video_prediction_tools/utils/runscript_generator/config_postprocess.py b/video_prediction_tools/utils/runscript_generator/config_postprocess.py index dd9dc349a3fb1c6bc1af63c179ce850efa3b15ef..a821ce7b5ce7430fac3162a6969d1edab8b8de39 100755 --- a/video_prediction_tools/utils/runscript_generator/config_postprocess.py +++ b/video_prediction_tools/utils/runscript_generator/config_postprocess.py @@ -29,8 +29,9 @@ class Config_Postprocess(Config_runscript_base): self.model = None self.checkpoint_dir = None self.results_dir = None + self.lquick = None # list of variables to be written to runscript - self.list_batch_vars = ["VIRT_ENV_NAME", "results_dir", "checkpoint_dir", "model"] + self.list_batch_vars = ["VIRT_ENV_NAME", "results_dir", "checkpoint_dir", "model", "lquick"] # copy over method for keyboard interaction self.run_config = Config_Postprocess.run_postprocess # @@ -87,8 +88,22 @@ class Config_Postprocess(Config_runscript_base): base_dir, exp_dir_base, exp_dir = "/"+os.path.join(*cp_dir_split[:-4]), cp_dir_split[-3], cp_dir_split[-1] self.runscript_target = self.rscrpt_tmpl_prefix + self.dataset + "_" + exp_dir + ".sh" - # finally, set results_dir + # Set results_dir self.results_dir = os.path.join(base_dir, "results", exp_dir_base, self.model, exp_dir) + + return + # Decide if quick evaluation should be performed + quick_req_str = "Should a reduced, quick evalutaion be performed (yes/no):" + quick_err = ValueError("Pass either True or False") + + self.lquick = Config_Postprocess.keyboard_interaction(quick_req_str, Config_Postprocess.check_quick, + quick_err, ntries=2) + if self.lquick.lower() == "yes": + self.lquick = "-lquick" + ### TO BE ADDDED ### + # * SELECTION OF EVALUATED MODEL + else: + self.lquick = " " # # ----------------------------------------------------------------------------------- # @@ -194,3 +209,23 @@ class Config_Postprocess(Config_runscript_base): else: if not silent: print("Passed directory '{0}' does not exist!".format(checkpoint_dir)) return status + # + # ----------------------------------------------------------------------------------- + # + @staticmethod + def check_quick(decision, silent=False): + """ + Checks if decision corresponds either to yes or no (both can be capitalized) + :param decision: yes-/no-decision from keyboard interaction + :param silent: flag if print-statement are executed + :return: status with True confirming success + """ + status = False + valid_decisions = ["yes", "no"] + if decision.lower() in valid_decisions: + status = True + else: + if not silent: + print("{0} is not a valid entry. Pass either yes or no.".format(str(decision)) + + return status \ No newline at end of file