diff --git a/test/test_visualize_postprocess.py b/test/test_visualize_postprocess.py index aebcda3754dd3599c1da1c7c5b0cec1df8364b94..458430f632cf2d2ccf374ef1ed5e76dc3a7b8b50 100644 --- a/test/test_visualize_postprocess.py +++ b/test/test_visualize_postprocess.py @@ -17,7 +17,8 @@ batch_size = 2 num_samples = 16 num_stochastic_samples = 2 gpu_mem_frac = 0.5 -seed=12345 +seed = 12345 +eval_metrics=["mse", "psnr"] class MyClass: @@ -31,28 +32,25 @@ args = MyClass(results_dir) @pytest.fixture(scope="module") def vis_case1(): return Postprocess(results_dir=results_dir,checkpoint=checkpoint, - mode=mode,batch_size=batch_size,num_samples=num_samples,num_stochastic_samples=num_stochastic_samples, - gpu_mem_frac=gpu_mem_frac,seed=seed,args=args) + mode=mode,batch_size=batch_size, num_samples=num_samples, + num_stochastic_samples=num_stochastic_samples, + seed=seed,args=args,eval_metrics=eval_metrics) ######instance2 num_samples2 = 200000 @pytest.fixture(scope="module") def vis_case2(): - return Postprocess(results_dir=results_dir,checkpoint=checkpoint, - mode=mode,batch_size=batch_size,num_samples=num_samples2,num_stochastic_samples=num_stochastic_samples, - gpu_mem_frac=gpu_mem_frac,seed=seed,args=args) + return Postprocess(results_dir=results_dir, checkpoint=checkpoint, + mode=mode, batch_size=batch_size, num_samples=num_samples2, + num_stochastic_samples=num_stochastic_samples, + seed=seed, args=args,eval_metrics=eval_metrics) def test_load_jsons(vis_case1): - vis_case1.set_seed() - vis_case1.save_args_to_option_json() - vis_case1.copy_data_model_json() - vis_case1.load_jsons() assert vis_case1.dataset == "era5" assert vis_case1.model == "savp" assert vis_case1.input_dir_tfr == "/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/era5-Y2007-2019M01to12-92x56-3840N0000E-2t_tcc_t_850/tfrecords_seq_len_24" assert vis_case1.run_mode == "deterministic" def test_get_metadata(vis_case1): - vis_case1.get_metadata() assert vis_case1.height == 56 assert vis_case1.width == 92 assert vis_case1.vars_in[0] == "2t" @@ -60,20 +58,11 @@ def test_get_metadata(vis_case1): def test_setup_test_dataset(vis_case1): - vis_case1.setup_test_dataset() vis_case1.test_dataset.mode == mode - -#def test_copy_data_model_json(vis_case1): -# vis_case1.copy_data_model_json() -# isfile_copy = os.path.isfile(os.path.join(checkpoint,"options.json")) -# assert isfile_copy == True -# isfile_copy_model_hpamas = os.path.isfile(os.path.join(checkpoint,"model_hparams.json")) -# assert isfile_copy_model_hpamas == True - def test_setup_num_samples_per_epoch(vis_case1): vis_case1.setup_test_dataset() - vis_case1.setup_num_samples_per_epoch() + vis_case1.setup_num_samples_per_epoch() assert vis_case1.num_samples_per_epoch == num_samples def test_get_data_params(vis_case1): @@ -86,14 +75,13 @@ def test_run_deterministic(vis_case1): vis_case1.init_session() vis_case1.restore(vis_case1.sess,vis_case1.checkpoint) vis_case1.sample_ind = 0 - vis_case1.init_eval_metrics_list() vis_case1.input_results,vis_case1.input_images_denorm_all, vis_case1.t_starts = vis_case1.run_and_plot_inputs_per_batch() assert len(vis_case1.t_starts_results) == batch_size ts_1 = vis_case1.t_starts[0][0] year = str(ts_1)[:4] month = str(ts_1)[4:6] filename = "ecmwf_era5_" + str(ts_1)[2:] + ".nc" - fl = os.path.join("/p/project/deepacf/deeprain/video_prediction_shared_folder/extractedData",year,month,filename) + fl = os.path.join("/p/project/deepacf/deeprain/video_prediction_shared_folder/extractedData",year, month, filename) print("netCDF file name:",fl) with Dataset(fl,"r") as data_file: t2_var = data_file.variables["2t"][0,:,:] @@ -105,10 +93,11 @@ def test_run_deterministic(vis_case1): input_img_min = np.min(input_image) print("input_image",input_image[0,:10]) assert t2_max == input_img_max - assert t2_min == input_img_min + assert t2_min == input_img_min feed_dict = {input_ph: vis_case1.input_results[name] for name, input_ph in vis_case1.inputs.items()} gen_images = vis_case1.sess.run(vis_case1.video_model.outputs['gen_images'], feed_dict=feed_dict) + ############Test persistenct value############# vis_case1.ts = Postprocess.generate_seq_timestamps(vis_case1.t_starts[0], len_seq=vis_case1.sequence_length) vis_case1.get_and_plot_persistent_per_sample(sample_id=0) @@ -124,6 +113,14 @@ def test_run_deterministic(vis_case1): t2_per_max = np.max(t2_per_var) per_image_max = np.max(vis_case1.persistence_images[0]) assert t2_per_max == per_image_max + + + +def test_run_determinstic_quantile_plot(vis_case1): + vis_case1.init_metric_ds() + + + #def test_make_test_dataset_iterator(vis_case1): # vis_case1.make_test_dataset_iterator() # pass @@ -159,7 +156,4 @@ def test_run_deterministic(vis_case1): -############Test case 2################## - - diff --git a/video_prediction_tools/main_scripts/main_visualize_postprocess.py b/video_prediction_tools/main_scripts/main_visualize_postprocess.py index cc580b3e59e5aeed8e12aba2f4cccce2c870e92e..58cb731f1b83d73e47598684118dee5556e65590 100644 --- a/video_prediction_tools/main_scripts/main_visualize_postprocess.py +++ b/video_prediction_tools/main_scripts/main_visualize_postprocess.py @@ -31,7 +31,7 @@ from postprocess_plotting import plot_avg_eval_metrics, plot_cond_quantile, crea class Postprocess(TrainModel): def __init__(self, results_dir=None, checkpoint=None, mode="test", batch_size=None, num_stochastic_samples=1, - stochastic_plot_id=0, gpu_mem_frac=None, seed=None, channel=0, args=None, run_mode="deterministic", + stochastic_plot_id=0, seed=None, channel=0, args=None, run_mode="deterministic", eval_metrics=None): """ Initialization of the class instance for postprocessing (generation of forecasts from trained model + @@ -43,7 +43,6 @@ class Postprocess(TrainModel): :param num_stochastic_samples: number of ensemble members for variational models (SAVP, VAE), default: 1 not supported yet!!! :param stochastic_plot_id: not supported yet! - :param gpu_mem_frac: fraction of GPU memory to be pre-allocated :param seed: Integer controlling randomization :param channel: Channel of interest for statistical evaluation :param args: namespace of parsed arguments @@ -54,7 +53,6 @@ class Postprocess(TrainModel): self.results_dir = self.output_dir = os.path.normpath(results_dir) _ = check_dir(self.results_dir, lcreate=True) self.batch_size = batch_size - self.gpu_mem_frac = gpu_mem_frac self.seed = seed self.set_seed() self.num_stochastic_samples = num_stochastic_samples @@ -72,7 +70,7 @@ class Postprocess(TrainModel): self.nboots_block = 1000 self.block_length = 7 * 24 # this corresponds to a block length of 7 days in case of hourly forecasts - # initialize evrything to get an executable Postprocess instance + # initialize everything to get an executable Postprocess instance self.save_args_to_option_json() # create options.json-in results directory self.copy_data_model_json() # copy over JSON-files from model directory # get some parameters related to model and dataset @@ -779,8 +777,8 @@ class Postprocess(TrainModel): Postprocess.clean_obj_attribute(self, "eval_metrics_ds") # the variables for conditional quantile plot - var_fcst = "{0}_{1}_fcst".format(self.vars_in[self.channel], self.model) - var_ref = "{0}_ref".format(self.vars_in[self.channel]) + var_fcst = self.cond_quantile_vars[0] + var_ref = self.cond_quantile_vars[1] data_fcst = get_era5_varatts(self.cond_quantiple_ds[var_fcst], self.cond_quantiple_ds[var_fcst].name) data_ref = get_era5_varatts(self.cond_quantiple_ds[var_ref], self.cond_quantiple_ds[var_ref].name) @@ -1091,7 +1089,7 @@ class Postprocess(TrainModel): if not set(varnames).issubset(ds_in.data_vars): raise ValueError("%{0}: Could not find all variables ({1}) in input dataset ds_in.".format(method, varnames_str)) - + #Bing : why using dtype as an aurument since it seems you only want ton configure dtype as np.double if dtype is None: dtype = np.double else: diff --git a/video_prediction_tools/utils/general_utils.py b/video_prediction_tools/utils/general_utils.py index 8e003e4533bfb62a943d6b6b9f8d13513558bb88..5d0fb397ef4be4d9c1e8bd3aad6d80bdc1a9925b 100644 --- a/video_prediction_tools/utils/general_utils.py +++ b/video_prediction_tools/utils/general_utils.py @@ -12,7 +12,6 @@ Provides: * get_unique_vars # import modules import os -import sys import numpy as np import xarray as xr @@ -138,7 +137,7 @@ def check_dir(path2dir: str, lcreate=False): if (path2dir is None) or not isinstance(path2dir, str): raise ValueError("%{0}: path2dir must be a string defining a pat to a directory.".format(method)) - if os.path.isdir(path2dir): + elif os.path.isdir(path2dir): return True else: if lcreate: @@ -177,7 +176,7 @@ def provide_default(dict_in, keyname, default=None, required=False): return dict_in[keyname] -def get_era5_varatts(data_arr: xr.DataArray, name): +def get_era5_varatts(data_arr: xr.DataArray, name: str): """ Writes longname and unit to data arrays given their name is known :param data_arr: the data array