Skip to content
Snippets Groups Projects
Commit aabaf749 authored by gong1's avatar gong1
Browse files

Adopt the unitest for main_visulization_postprocess for Michael update version

parent c9200cc1
Branches
Tags
No related merge requests found
Pipeline #69459 passed
...@@ -59,9 +59,9 @@ def test_run_deterministic(vis_case1): ...@@ -59,9 +59,9 @@ def test_run_deterministic(vis_case1):
vis_case1.init_session() vis_case1.init_session()
vis_case1.restore(vis_case1.sess,vis_case1.checkpoint) vis_case1.restore(vis_case1.sess,vis_case1.checkpoint)
vis_case1.sample_ind = 0 vis_case1.sample_ind = 0
vis_case1.input_results,vis_case1.input_images_denorm_all, vis_case1.t_starts = vis_case1.get_input_data_per_batch(vis_case1.inputs) input_results,input_images_denorm_all,t_starts = vis_case1.get_input_data_per_batch(vis_case1.inputs)
assert len(vis_case1.t_starts) == batch_size assert len(t_starts) == batch_size
ts_1 = vis_case1.t_starts[0][0] ts_1 = t_starts[0][0]
year = str(ts_1)[:4] year = str(ts_1)[:4]
month = str(ts_1)[4:6] month = str(ts_1)[4:6]
filename = "ecmwf_era5_" + str(ts_1)[2:] + ".nc" filename = "ecmwf_era5_" + str(ts_1)[2:] + ".nc"
...@@ -72,36 +72,42 @@ def test_run_deterministic(vis_case1): ...@@ -72,36 +72,42 @@ def test_run_deterministic(vis_case1):
t2_var = np.array(t2_var) t2_var = np.array(t2_var)
t2_max = np.max(t2_var[117:173,0:92]) t2_max = np.max(t2_var[117:173,0:92])
t2_min = np.min(t2_var[117:173,0:92]) t2_min = np.min(t2_var[117:173,0:92])
input_image = np.array(vis_case1.input_images_denorm_all)[0,0,:,:,0] #get the first batch id and 1st sequence image input_image = np.array(input_images_denorm_all)[0,0,:,:,0] #get the first batch id and 1st sequence image
input_img_max = np.max(input_image) input_img_max = np.max(input_image)
input_img_min = np.min(input_image) input_img_min = np.min(input_image)
print("input_image",input_image[0,:10]) print("input_image",input_image[0,:10])
assert t2_max == input_img_max assert t2_max == input_img_max
assert t2_min == input_img_min assert t2_min == input_img_min
sample_ind = 0
feed_dict = {input_ph: vis_case1.input_results[name] for name, input_ph in vis_case1.inputs.items()} feed_dict = {input_ph: input_results[name] for name, input_ph in vis_case1.inputs.items()}
gen_images = vis_case1.sess.run(vis_case1.video_model.outputs['gen_images'], feed_dict=feed_dict) gen_images = vis_case1.sess.run(vis_case1.video_model.outputs['gen_images'], feed_dict=feed_dict)
gen_images_denorm = vis_case1.denorm_images_all_channels(gen_images, vis_case1.vars_in, vis_case1.norm_cls,
norm_method="minmax")
############Test persistenct value############# ############Test persistenct value#############
vis_case1.ts = Postprocess.generate_seq_timestamps(vis_case1.t_starts[0], len_seq=vis_case1.sequence_length) times_0, init_times = vis_case1.get_init_time(t_starts)
vis_case1.get_and_plot_persistent_per_sample(sample_id=0) batch_ds = vis_case1.create_dataset(input_images_denorm_all, gen_images_denorm, init_times)
ts_1_per = (datetime.datetime.strptime(str(ts_1), '%Y%m%d%H') - datetime.timedelta(hours=23)).strftime("%Y%m%d%H") nbs = np.minimum(vis_case1.batch_size, vis_case1.num_samples_per_epoch - sample_ind)
times_seq = (pd.date_range(times_0[0], periods=int(vis_case1.sequence_length), freq="h")).to_pydatetime()
persistence_seq, _ = Postprocess.get_persistence(times_seq, vis_case1.input_dir_pkl)
ts_1_per = (pd.to_datetime(times_0[0]) - datetime.timedelta(hours=23)).strftime("%Y%m%d%H")
year_per = str(ts_1_per)[:4] year_per = str(ts_1_per)[:4]
month_per = str(ts_1_per)[4:6] month_per = str(ts_1_per)[4:6]
filename_per = "ecmwf_era5_" + str(ts_1_per)[2:] + ".nc" filename_per = "ecmwf_era5_" + str(ts_1_per)[2:] + ".nc"
fl_per = os.path.join("/p/project/deepacf/deeprain/video_prediction_shared_folder/extractedData",year_per,month_per,filename_per)
fl_per = os.path.join("/p/scratch/deepacf/deeprain/ambs_era5/extractedData",year_per,month_per,filename_per)
with Dataset(fl_per,"r") as data_file: with Dataset(fl_per,"r") as data_file:
t2_var_per = data_file.variables["2t"][0,117:173,0:92] t2_var_per = data_file.variables["2t"][0,117:173,0:92]
t2_per_var = np.array(t2_var_per) t2_per_var = np.array(t2_var_per)
t2_per_max = np.max(t2_per_var) t2_per_max = np.max(t2_per_var)
per_image_max = np.max(vis_case1.persistence_images[0]) per_image_max = np.max(persistence_seq[0])
assert t2_per_max == per_image_max assert t2_per_max == per_image_max
def test_run_determinstic_quantile_plot(vis_case1): #def test_run_determinstic_quantile_plot(vis_case1):
vis_case1.init_metric_ds() # vis_case1.init_metric_ds()
... ...
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment