Skip to content
Snippets Groups Projects
Commit d280a888 authored by gong1's avatar gong1
Browse files

add unitest for testing eval_metric_ds

parent aabaf749
Branches
Tags
No related merge requests found
Pipeline #69500 passed
...@@ -56,9 +56,12 @@ def test_get_data_params(vis_case1): ...@@ -56,9 +56,12 @@ def test_get_data_params(vis_case1):
assert vis_case1.future_length == 12 assert vis_case1.future_length == 12
def test_run_deterministic(vis_case1): def test_run_deterministic(vis_case1):
vis_case1.num_samples_per_epoch = 20
vis_case1.init_session() vis_case1.init_session()
vis_case1.restore(vis_case1.sess,vis_case1.checkpoint) vis_case1.restore(vis_case1.sess,vis_case1.checkpoint)
vis_case1.sample_ind = 0 print("fcast-product",vis_case1.fcst_products)
eval_metric_ds = Postprocess.init_metric_ds(vis_case1.fcst_products, vis_case1.eval_metrics, vis_case1.vars_in[vis_case1.channel], vis_case1.num_samples_per_epoch, vis_case1.future_length)
input_results,input_images_denorm_all,t_starts = vis_case1.get_input_data_per_batch(vis_case1.inputs) input_results,input_images_denorm_all,t_starts = vis_case1.get_input_data_per_batch(vis_case1.inputs)
assert len(t_starts) == batch_size assert len(t_starts) == batch_size
ts_1 = t_starts[0][0] ts_1 = t_starts[0][0]
...@@ -87,6 +90,7 @@ def test_run_deterministic(vis_case1): ...@@ -87,6 +90,7 @@ def test_run_deterministic(vis_case1):
times_0, init_times = vis_case1.get_init_time(t_starts) times_0, init_times = vis_case1.get_init_time(t_starts)
batch_ds = vis_case1.create_dataset(input_images_denorm_all, gen_images_denorm, init_times) batch_ds = vis_case1.create_dataset(input_images_denorm_all, gen_images_denorm, init_times)
nbs = np.minimum(vis_case1.batch_size, vis_case1.num_samples_per_epoch - sample_ind) nbs = np.minimum(vis_case1.batch_size, vis_case1.num_samples_per_epoch - sample_ind)
times_seq = (pd.date_range(times_0[0], periods=int(vis_case1.sequence_length), freq="h")).to_pydatetime() times_seq = (pd.date_range(times_0[0], periods=int(vis_case1.sequence_length), freq="h")).to_pydatetime()
persistence_seq, _ = Postprocess.get_persistence(times_seq, vis_case1.input_dir_pkl) persistence_seq, _ = Postprocess.get_persistence(times_seq, vis_case1.input_dir_pkl)
ts_1_per = (pd.to_datetime(times_0[0]) - datetime.timedelta(hours=23)).strftime("%Y%m%d%H") ts_1_per = (pd.to_datetime(times_0[0]) - datetime.timedelta(hours=23)).strftime("%Y%m%d%H")
...@@ -105,6 +109,30 @@ def test_run_deterministic(vis_case1): ...@@ -105,6 +109,30 @@ def test_run_deterministic(vis_case1):
assert t2_per_max == per_image_max assert t2_per_max == per_image_max
##Test evaluation metric
for ivar, var in enumerate(vis_case1.vars_in):
batch_ds["{0}_persistence_fcst".format(var)].loc[dict(init_time=init_times[0])] = \
persistence_seq[vis_case1.context_frames-1:, :, :, ivar]
eval_metric_ds = vis_case1.populate_eval_metric_ds(eval_metric_ds,batch_ds,sample_ind,vis_case1.vars_in[vis_case1.channel])
##now manuly calculate the mse and see if values is the same as the ones in eval_metric_ds
#calculate the mse between generateed images and reference images
sample_gen = gen_images_denorm[0,vis_case1.context_frames-1:,:,:,vis_case1.channel]
sample_ref = input_images_denorm_all[0,vis_case1.context_frames:,:,:,vis_case1.channel]
sample_gen_ref_mse_t0 = np.mean((sample_gen[0] - sample_ref[0])**2)
metric_name = "2t_savp_mse"
print("eval_metric_ds",eval_metric_ds)
assert eval_metric_ds[metric_name][0,0] == sample_gen_ref_mse_t0
sample_gen_ref_mse_t5 = np.mean((sample_gen[5] - sample_ref[5])**2)
assert eval_metric_ds[metric_name][0,5] == sample_gen_ref_mse_t5
#def test_run_determinstic_quantile_plot(vis_case1): #def test_run_determinstic_quantile_plot(vis_case1):
# vis_case1.init_metric_ds() # vis_case1.init_metric_ds()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment