diff --git a/test/test_meta_postprocess.py b/test/test_meta_postprocess.py
index 79883adb09450aebf4d5525f8e15357c7598e58c..1892b0c3c22d723eeba5e08cce238a5177e3860b 100644
--- a/test/test_meta_postprocess.py
+++ b/test/test_meta_postprocess.py
@@ -11,6 +11,11 @@ import pytest
 analysis_config = "/p/home/jusers/gong1/juwels/ambs/video_prediction_tools/analysis_config/analysis_test.json"
 analysis_dir = "/p/home/jusers/gong1/juwels/video_prediction_shared_folder/analysis/bing_test1"
+test_nc_fl = "/p/project/deepacf/deeprain/video_prediction_shared_folder/results/era5-Y2015to2017M01to12-64x64-3930N0000E-T2_MSL_gph500/convLSTM/20201221T181605_gong1_sunny/vfp_date_2017030118_sample_ind_13.nc"
+test_dir = "/p/project/deepacf/deeprain/video_prediction_shared_folder/results/era5-Y2015to2017M01to12-64x64-3930N0000E-T2_MSL_gph500/convLSTM/20201221T181605_gong1_sunny"
+
+
+
 #setup instance
 @pytest.fixture(scope="module")
@@ -34,38 +39,41 @@ def test_copy_analysis_config(analysis_inst):
 def test_load_analysis_config(analysis_inst):
     analysis_inst.load_analysis_config()
     metrics_test = analysis_inst.metrics[0]
+    test_dir_read = analysis_inst.results_dirs[0]
     assert metrics_test == "mse"
+    assert test_dir_read == test_dir

 def test_read_values_by_var_from_nc(analysis_inst):
-    file_nc = "/p/home/jusers/gong1/juwels/video_prediction_shared_folder/results/era5_test/convLSTM/20201130T1748_gong1/vfp_date_2017031312_sample_ind_229.nc"
+    file_nc = test_nc_fl
     real,persistent,forecast,time_forecast = analysis_inst.read_values_by_var_from_nc(fl_nc = file_nc)
     assert len(real) == len(persistent) == len(forecast)
     assert len(time_forecast) == len(forecast)

 def test_calculate_metric_one_dir(analysis_inst):
-    file_nc = "/p/home/jusers/gong1/juwels/video_prediction_shared_folder/results/era5_test/convLSTM/20201130T1748_gong1/vfp_date_2017031312_sample_ind_229.nc"
+    file_nc = test_nc_fl
     real,persistent,forecast,time_forecast = analysis_inst.read_values_by_var_from_nc(fl_nc = file_nc)
-    eval_persistent,eval_forecast = analysis_inst.calculate_metric_one_img(real, persistent,forecast,metric="mse")
+    eval_forecast = analysis_inst.calculate_metric_one_img(real,forecast,metric="mse")

 def test_load_results_dir_parameters(analysis_inst):
     analysis_inst.load_results_dir_parameters()
     assert len(analysis_inst.compare_by_values) == 2
-
-def test_calculate_metric_all_dirs(analysis_inst):
-    analysis_inst.calculate_metrics_all_dirs()
-    assert list(analysis_inst.eval_all.keys())[0] == analysis_inst.results_dirs[0]
-    print(analysis_inst.eval_all["/p/home/jusers/gong1/juwels/video_prediction_shared_folder/results/era5_test/convLSTM/20201130T1748_gong1"].keys())
-    assert len(analysis_inst.eval_all["/p/home/jusers/gong1/juwels/video_prediction_shared_folder/results/era5_test/convLSTM/20201130T1748_gong1"]["persistent"]["mse"][1]) == 10
+    analysis_inst.calculate_metric_all_dirs()
+
+#def test_calculate_metric_all_dirs(analysis_inst):
+#    analysis_inst.calculate_metrics_all_dirs()
+#    assert list(analysis_inst.eval_all.keys())[0] == analysis_inst.results_dirs[0]
+#    print(analysis_inst.eval_all[test_dir].keys())
+#    assert len(analysis_inst.eval_all[test_dir]["persistent"]["mse"][1]) == 10

-def test_calculate_mean_vars_forecast(analysis_inst):
-    analysis_inst.calculate_metrics_all_dirs()
-    analysis_inst.calculate_mean_vars_forecast()
-
-    assert len(analysis_inst.results_dict["/p/home/jusers/gong1/juwels/video_prediction_shared_folder/results/era5_test/convLSTM/20201130T1748_gong1"]["forecast"]) == 2
+#def test_calculate_mean_vars_forecast(analysis_inst):
+#    analysis_inst.calculate_metrics_all_dirs()
+#    analysis_inst.calculate_mean_vars_forecast()
+#    assert len(analysis_inst.results_dict[test_dir]["forecasts"]) == 2

 def test_plot_results(analysis_inst):
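Note on the tests disabled above: they still target the removed eval_all dictionary API. Since calculate_metric_all_dirs now returns an xarray.DataArray with dims ["results_dirs", "samples", "time_forecast"] (see the main_meta_postprocess.py diff below), a re-enabled test could look roughly like the following sketch; the test name is hypothetical and only the fixture and module-level test_dir defined above are assumed.

# Hypothetical replacement for the disabled dict-based tests.
def test_calculate_metric_all_dirs_returns_dataarray(analysis_inst):
    evals_forecast = analysis_inst.calculate_metric_all_dirs(is_persistent=False, metric="mse")
    # one metric value per (results_dir, sample, forecast timestamp)
    assert evals_forecast.dims == ("results_dirs", "samples", "time_forecast")
    # the first entry along "results_dirs" should be the configured directory
    assert str(evals_forecast["results_dirs"].values[0]) == test_dir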
diff --git a/video_prediction_tools/main_scripts/main_meta_postprocess.py b/video_prediction_tools/main_scripts/main_meta_postprocess.py
index 084759765243642edc5f710798e2809688656d13..51f6e19aaa3c0525a6304b38251474d0793a0a8f 100644
--- a/video_prediction_tools/main_scripts/main_meta_postprocess.py
+++ b/video_prediction_tools/main_scripts/main_meta_postprocess.py
@@ -14,7 +14,9 @@ import numpy as np
 import shutil
 import glob
 from netCDF4 import Dataset
-from video_prediction.metrics import *
+from model_modules.video_prediction.metrics import *
+import xarray as xr
+

 class MetaPostprocess(object):
     def __init__(self, analysis_config=None, analysis_dir=None, stochastic_ind=0, forecast_type="deterministic"):
@@ -81,140 +83,72 @@ class MetaPostprocess(object):
         #load var prediction, real and persistent values
         real = fl["/analysis/reference/"].variables[var][:]
         persistent = fl["/analysis/persistent/"].variables[var][:]
-        forecast = fl["/forecast/"+var+"/stochastic"].variables[str(stochastic_ind)][:]
+        forecast = fl["/forecasts/"+var+"/stochastic"].variables[str(stochastic_ind)][:]
         time_forecast = fl.variables["time_forecast"][:]
         return real, persistent, forecast, time_forecast

     @staticmethod
-    def calculate_metric_one_img(real, persistent,forecast,metric="mse"):
+    def calculate_metric_one_img(real,forecast,metric="mse"):
         if metric == "mse":
             #compare real and persistent
-            eval_persistent = mse_imgs(real,persistent)
             eval_forecast = mse_imgs(real, forecast)
         elif metric == "psnr":
-            eval_persistent = psnr_imgs(real,forecast)
             eval_forecast = psnr_imgs(real,forecast)
-        return eval_persistent, eval_forecast
+        return eval_forecast

     @staticmethod
     def reshape_eval_to_one_dim(values):
         return np.array(values).flatten()

-    def calculate_metrics_all_dirs(self):
-        """
-        Calculate the all the metrics for persistent and forecast results
-        eval_all is dictionary,
-        eval_all = {
-            <results_dir>:
-                {
-                    "persistent":
-                        {
-                            <metric_name1> : eval_values,
-                            <metric_name2> : eval_values
-                        }
-
-                    "forecast" :
-                        {
-                            <metric_name1> : eval_values,
-                            <metric_name2> : eval_values
-                        }
-
-                }
-            }
-
+    def calculate_metric_all_dirs(self,is_persistent=False,metric="mse"):
         """
-        self.eval_all = {}
-        for results_dir in self.results_dirs:
-            self.eval_all.update({results_dir: {"persistent":None}})
-            self.eval_all.update({results_dir: {"forecast":None}})
-            real_all, persistent_all, forecast_all, self.time_forecast = MetaPostprocess.load_prediction_and_real_from_one_dir(results_dir,var="T2",stochastic_ind=self.stochastic_ind)
-            for metric in self.metrics:
-                self.eval_persistent_all = []
-                self.eval_forecast_all = []
-                #loop for real data
-                for idx in range(len(real_all)):
-                    eval_persistent_per_sample_over_ts = []
-                    eval_forecast_per_sample_over_ts = []
-
-                    #loop the forecast time
-                    for time in range(len(self.time_forecast)):
-                        #loop for each sample and each timestamp
-                        self.eval_persistent, self.eval_forecast = MetaPostprocess.calculate_metric_one_img(real_all[idx][time],persistent_all[idx][time],forecast_all[idx][time], metric=metric)
-                        eval_persistent_per_sample_over_ts.append(self.eval_persistent)
-                        eval_forecast_per_sample_over_ts.append(self.eval_forecast)
-
-                self.eval_persistent_all.append(list(eval_persistent_per_sample_over_ts))
-                self.eval_forecast_all.append(list(eval_forecast_per_sample_over_ts))
-                #the shape of self.eval_persistent_all is [samples,time_forecast]
-            self.eval_all[results_dir]["persistent"] = {metric: list(self.eval_persistent_all)}
-            self.eval_all[results_dir]["forecast"] = {metric: list(self.eval_forecast_all)}
+        Return the evaluation metrics for the persistent and the forecasting model over the forecast timestamps

-    def save_metrics_all_dir_to_json(self):
-        with open("metrics_results.json","w") as f:
-            json.dump(self.eval_all,f)
-
-
-    def load_results_dir_parameters(self,compare_by="model"):
-        self.compare_by_values = []
-        for results_dir in self.results_dirs:
-            with open(os.path.join(results_dir, "options_checkpoints.json")) as f:
-                self.options = json.loads(f.read())
-                print("self.options:",self.options)
-                #if self.compare_by == "model":
-                self.compare_by_values.append(self.options[compare_by])
-
-
-    def calculate_mean_vars_forecast(self):
-        """
-        Calculate the mean varations of persistent and forecast evalaution metrics
+        return:
+            evals_forecast: xarray.DataArray, the evaluation metric values with dimensions [results_dirs, samples, time_forecast]
+        """
-        is_first_persistent = False
+        eval_forecast_all_dirs = []
         for results_dir in self.results_dirs:
-            evals = self.eval_all[results_dir]
-            eval_persistent = evals["persistent"]
-            eval_forecast = evals["forecast"]
-            self.results_dict = {}
-            for metric in self.metrics:
-                err_stat = []
+            real_all, persistent_all, forecast_all, self.time_forecast = MetaPostprocess.load_prediction_and_real_from_one_dir(results_dir,var="T2",stochastic_ind=self.stochastic_ind)
+
+            if is_persistent: forecast_all = persistent_all
+            eval_forecast_all = []
+            #loop for real data
+            for idx in range(len(real_all)):
+                eval_forecast_per_sample_over_ts = []
+                #loop the forecast time
                 for time in range(len(self.time_forecast)):
-                    forecast_values_all = list(eval_forecast[metric])[:][time]
-                    persistent_values_all = list(eval_persistent[metric])[:][time]
-                    forecast_mean = np.mean(np.array(forecast_values_all),axis=0)
-                    persistent_mean = np.mean(np.array(persistent_values_all),axis=0)
-                    forecast_vars = np.var(np.array(forecast_values_all),axis=0)
-                    persistent_vars = np.var(np.array(persistent_values_all),axis=0)
-                    #[time,mean,vars]
-            self.results_dict[results_dir] = {"persistent":[persistent_mean, persistent_vars]}
-            self.results_dict[results_dir].update({"forecast":[forecast_mean,forecast_vars]})
-
+                    #loop for each sample and each timestamp
+                    eval_forecast = MetaPostprocess.calculate_metric_one_img(real_all[idx][time],forecast_all[idx][time], metric=metric)
+                    eval_forecast_per_sample_over_ts.append(eval_forecast)
+
+                eval_forecast_all.append(list(eval_forecast_per_sample_over_ts))
+            eval_forecast_all_dirs.append(eval_forecast_all)
+
+        times = list(range(len(self.time_forecast)))
+        samples = list(range(len(real_all)))
+        print("shape of list",np.array(eval_forecast_all_dirs).shape)
+        evals_forecast = xr.DataArray(eval_forecast_all_dirs, coords=[self.results_dirs, samples, times], dims=["results_dirs", "samples","time_forecast"])
+        return evals_forecast
+
     def plot_results(self,one_persistent=True):
         """
         Plot the mean and vars for the user-defined metrics
         """
-
-        self.load_results_dir_parameters()
-        is_first_persistent=True
-        mean_all_persistent = []
-        vars_all_persistent = []
-        mean_all_model = []
-        vars_all_model = []
-        for results_dir in self.results_dirs:
-            mean_all_model.append(self.results_dict[results_dir]["forecast"][0])
-            vars_all_model.append(self.results_dict[results_dir]["forecast"][1])
-
-            if one_persistent==True:
-                mean_all_model.append(self.results_dict[results_dir]["persistent"][0])
-                vars_all_model.append(self.results_dict[results_dir]["persistent"][1])
-                self.compare_by_values.append("persistent")
-
-
+        self.load_results_dir_parameters(compare_by="model")
+        evals_forecast = self.calculate_metric_all_dirs(is_persistent=False,metric="mse")
+        t = evals_forecast["time_forecast"]
+        mean_forecast = evals_forecast.groupby("time_forecast").mean(dim="samples").values
+        var_forecast = evals_forecast.groupby("time_forecast").var(dim="samples").values
+        print("mean_forecast",mean_forecast)
         x = np.array(self.compare_by_values)
-        y = np.array(mean_all_model)
-        e = np.array(vars_all_model)
-
-        plt.errorbar(x,y,e,linestyle="None",marker='^')
+        y = np.array(mean_forecast)
+        e = np.array(var_forecast)
+
+        plt.errorbar(t,y[0],e[0],linestyle="None",marker='^')
         plt.show()
         plt.savefig(os.path.join(self.analysis_dir,self.metrics[0]+".png"))
         plt.close()
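Note on the statistics in plot_results: because "time_forecast" is already a dimension of the returned DataArray, the groupby("time_forecast") step is redundant; reducing over "samples" alone yields the same per-timestamp mean and variance. A minimal sketch with made-up data (the dimension names follow the code above; the array sizes and dir names are dummies):

import numpy as np
import xarray as xr

# dummy metric values: 2 results dirs, 10 samples, 12 forecast steps (invented sizes)
vals = np.random.rand(2, 10, 12)
evals = xr.DataArray(vals,
                     coords=[["dir_a", "dir_b"], list(range(10)), list(range(12))],
                     dims=["results_dirs", "samples", "time_forecast"])
mean_forecast = evals.mean(dim="samples")  # dims: (results_dirs, time_forecast)
var_forecast = evals.var(dim="samples")    # dims: (results_dirs, time_forecast)
# mean_forecast[0] and var_forecast[0] correspond to y[0] and e[0] in plot_results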
diff --git a/video_prediction_tools/main_scripts/main_visualize_postprocess.py b/video_prediction_tools/main_scripts/main_visualize_postprocess.py
index 382cd6a7c82db3b263c28b01d85bc19ecb248cc1..c78eecf87fa704bf06c700825910914f63110094 100644
--- a/video_prediction_tools/main_scripts/main_visualize_postprocess.py
+++ b/video_prediction_tools/main_scripts/main_visualize_postprocess.py
@@ -199,8 +199,8 @@ class Postprocess(TrainModel,ERA5Pkl2Tfrecords):
             self.test_handle, self.test_tf_dataset.output_types, self.test_tf_dataset.output_shapes)
         self.inputs = self.iterator.get_next()
         self.input_ts = self.inputs["T_start"]
-        if self.dataset == "era5" and self.model == "savp":
-            del self.inputs["T_start"]
+        #if self.dataset == "era5" and self.model == "savp":
+        #    del self.inputs["T_start"]

     def check_stochastic_samples_ind_based_on_model(self):
@@ -226,19 +226,23 @@ class Postprocess(TrainModel,ERA5Pkl2Tfrecords):
         """
         self.input_results = self.sess.run(self.inputs)
         self.input_images = self.input_results["images"]
-        self.t_starts_results = self.sess.run(self.input_ts)
+        self.t_starts_results = self.input_results["T_start"]
         print("t_starts_results:",self.t_starts_results)
         self.t_starts = self.t_starts_results
         #get one seq and the corresponding start time poin
         #self.t_starts = self.input_results["T_start"]
+        self.input_images_denorm_all = []
         for batch_id in range(self.batch_size):
             self.input_images_ = Postprocess.get_one_seq_from_batch(self.input_images,batch_id)
             #Renormalized data for inputs
             ts = Postprocess.generate_seq_timestamps(self.t_starts[batch_id],len_seq=self.sequence_length)
-            self.input_images_denorm = Postprocess.denorm_images_all_channels(self.stat_fl,self.input_images_,self.vars_in)
-            assert len(self.input_images_denorm.shape) == 4
-            Postprocess.plot_seq_imgs(imgs = self.input_images_denorm[self.context_frames+1:,:,:,0],lats=self.lats,lons=self.lons,ts=ts[self.context_frames+1:],label="Ground Truth",output_png_dir=self.results_dir)
-        return self.input_results, self.input_images,self.t_starts
+            input_images_denorm = Postprocess.denorm_images_all_channels(self.stat_fl,self.input_images_,self.vars_in)
+            assert len(input_images_denorm.shape) == 4
+            Postprocess.plot_seq_imgs(imgs = input_images_denorm[self.context_frames+1:,:,:,0],lats=self.lats,lons=self.lons,ts=ts[self.context_frames+1:],label="Ground Truth",output_png_dir=self.results_dir)
+
+            self.input_images_denorm_all.append(list(input_images_denorm))
+
+        return self.input_results, np.array(self.input_images_denorm_all),self.t_starts

     def run(self):
@@ -253,7 +257,7 @@ class Postprocess(TrainModel,ERA5Pkl2Tfrecords):
             if self.num_samples_per_epoch < self.sample_ind:
                 break
             else:
-                self.input_results, self.input_images, self.t_starts = self.run_and_plot_inputs_per_batch() #run the inputs and plot each sequence images
+                self.input_results, self.input_images_denorm_all, self.t_starts = self.run_and_plot_inputs_per_batch() #run the inputs and plot each sequence images

             feed_dict = {input_ph: self.input_results[name] for name, input_ph in self.inputs.items()}
             gen_images_stochastic = [] #[stochastic_ind,batch_size,seq_len,lat,lon,channels]
@@ -293,7 +297,7 @@ class Postprocess(TrainModel,ERA5Pkl2Tfrecords):
             print("persistent_images_per_batch",len(np.array(persistent_images_per_batch)))
             for batch_id in range(self.batch_size):
                 print("batch_id is here",batch_id)
-                self.save_to_netcdf_for_stochastic_generate_images(self.input_images[batch_id], persistent_images_per_batch[batch_id],
+                self.save_to_netcdf_for_stochastic_generate_images(self.input_images_denorm_all[batch_id], persistent_images_per_batch[batch_id],
                                                                    np.array(gen_images_stochastic)[:,batch_id,:,:,:,:],
                                                                    fl_name="vfp_date_{}_sample_ind_{}.nc".format(ts_batch[batch_id],self.sample_ind+batch_id))
@@ -412,6 +416,7 @@ class Postprocess(TrainModel,ERA5Pkl2Tfrecords):
            gen_images_stochastic: list/array (float), [stochastic_number,seq,lat,lon,channel]
            fl_name : str, the netcdf file name to be saved
         """
+        print("inputs for netcdf:",input_images_)
         assert (len(np.array(input_images_).shape)==len(np.array(gen_images_stochastic).shape))-1
         persistent_images_ = np.array(persistent_images_)
         assert len(persistent_images_.shape) == 4 #[seq,lat,lon,channel]
@@ -709,5 +714,4 @@ def main():

 if __name__ == '__main__':
-    main()
-
+    main()
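Note on the T_start change in run_and_plot_inputs_per_batch: every session run of a tf.data iterator's get_next op consumes a fresh element, so the removed second sess.run(self.input_ts) would pull timestamps from a different batch than the images fetched just before it. Reading "T_start" out of the already-fetched input_results keeps images and timestamps in sync. A toy TF1-style sketch of the pitfall (illustrative names and data, not the project's code):

import tensorflow as tf  # TF 1.x graph-mode semantics

dataset = tf.data.Dataset.from_tensor_slices({"images": [10, 20, 30], "T_start": [1, 2, 3]})
inputs = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    batch = sess.run(inputs)       # consumes element 0: images=10, T_start=1
    # wrong: a second run advances the iterator and returns T_start=2,
    # which no longer matches images=10
    # t_start = sess.run(inputs["T_start"])
    t_start = batch["T_start"]     # right: reuse the element already fetched -> 1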