diff --git a/test/test_visualize_postprocess.py b/test/test_visualize_postprocess.py index 92191c26e5565f9ff99b7ed29c07631d09431c85..d9a3cdc01bd2f803758d090928fb2b752cae8890 100644 --- a/test/test_visualize_postprocess.py +++ b/test/test_visualize_postprocess.py @@ -56,9 +56,12 @@ def test_get_data_params(vis_case1): assert vis_case1.future_length == 12 def test_run_deterministic(vis_case1): + vis_case1.num_samples_per_epoch = 20 vis_case1.init_session() vis_case1.restore(vis_case1.sess,vis_case1.checkpoint) - vis_case1.sample_ind = 0 + print("fcast-product",vis_case1.fcst_products) + eval_metric_ds = Postprocess.init_metric_ds(vis_case1.fcst_products, vis_case1.eval_metrics, vis_case1.vars_in[vis_case1.channel], vis_case1.num_samples_per_epoch, vis_case1.future_length) + input_results,input_images_denorm_all,t_starts = vis_case1.get_input_data_per_batch(vis_case1.inputs) assert len(t_starts) == batch_size ts_1 = t_starts[0][0] @@ -87,24 +90,49 @@ def test_run_deterministic(vis_case1): times_0, init_times = vis_case1.get_init_time(t_starts) batch_ds = vis_case1.create_dataset(input_images_denorm_all, gen_images_denorm, init_times) nbs = np.minimum(vis_case1.batch_size, vis_case1.num_samples_per_epoch - sample_ind) + times_seq = (pd.date_range(times_0[0], periods=int(vis_case1.sequence_length), freq="h")).to_pydatetime() persistence_seq, _ = Postprocess.get_persistence(times_seq, vis_case1.input_dir_pkl) ts_1_per = (pd.to_datetime(times_0[0]) - datetime.timedelta(hours=23)).strftime("%Y%m%d%H") - + year_per = str(ts_1_per)[:4] month_per = str(ts_1_per)[4:6] filename_per = "ecmwf_era5_" + str(ts_1_per)[2:] + ".nc" fl_per = os.path.join("/p/scratch/deepacf/deeprain/ambs_era5/extractedData",year_per,month_per,filename_per) with Dataset(fl_per,"r") as data_file: - t2_var_per = data_file.variables["2t"][0,117:173,0:92] + t2_var_per = data_file.variables["2t"][0,117:173,0:92] t2_per_var = np.array(t2_var_per) t2_per_max = np.max(t2_per_var) per_image_max = np.max(persistence_seq[0]) assert t2_per_max == per_image_max + + + ##Test evaluation metric + for ivar, var in enumerate(vis_case1.vars_in): + batch_ds["{0}_persistence_fcst".format(var)].loc[dict(init_time=init_times[0])] = \ + persistence_seq[vis_case1.context_frames-1:, :, :, ivar] + + eval_metric_ds = vis_case1.populate_eval_metric_ds(eval_metric_ds,batch_ds,sample_ind,vis_case1.vars_in[vis_case1.channel]) + ##now manuly calculate the mse and see if values is the same as the ones in eval_metric_ds + #calculate the mse between generateed images and reference images + sample_gen = gen_images_denorm[0,vis_case1.context_frames-1:,:,:,vis_case1.channel] + sample_ref = input_images_denorm_all[0,vis_case1.context_frames:,:,:,vis_case1.channel] + sample_gen_ref_mse_t0 = np.mean((sample_gen[0] - sample_ref[0])**2) + metric_name = "2t_savp_mse" + print("eval_metric_ds",eval_metric_ds) + assert eval_metric_ds[metric_name][0,0] == sample_gen_ref_mse_t0 + sample_gen_ref_mse_t5 = np.mean((sample_gen[5] - sample_ref[5])**2) + assert eval_metric_ds[metric_name][0,5] == sample_gen_ref_mse_t5 + + + + + + #def test_run_determinstic_quantile_plot(vis_case1): # vis_case1.init_metric_ds()