diff --git a/test/test_meta_postprocess.py b/test/test_meta_postprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..4bca02bd5494bc15d8e1d3ff453d2f6237700c08 --- /dev/null +++ b/test/test_meta_postprocess.py @@ -0,0 +1,77 @@ +__email__ = "b.gong@fz-juelich.de" +__author__ = "Bing Gong" +__date__ = "2020-12-04" + + +from main_scripts.main_meta_postprocess import * +import os +import pytest + +#Params +analysis_config = "/p/home/jusers/gong1/juwels/ambs/video_prediction_tools/analysis_config/analysis_test.json" +analysis_dir = "/p/home/jusers/gong1/juwels/video_prediction_shared_folder/analysis/bing_test1" +test_nc_fl = "/p/project/deepacf/deeprain/video_prediction_shared_folder/results/era5-Y2015to2017M01to12-64x64-3930N0000E-T2_MSL_gph500/convLSTM/20201221T181605_gong1_sunny/vfp_date_2017030118_sample_ind_13.nc" +test_dir = "/p/project/deepacf/deeprain/video_prediction_shared_folder/results/era5-Y2015to2017M01to12-64x64-3930N0000E-T2_MSL_gph500/convLSTM/20201221T181605_gong1_sunny" + +#setup instance +@pytest.fixture(scope="module") +def analysis_inst(): + return MetaPostprocess(analysis_config=analysis_config,analysis_dir=analysis_dir) + + +def test_create_analysis_dir(analysis_inst): + analysis_inst.create_analysis_dir() + is_path = os.path.isdir(analysis_dir) + assert is_path == True + + +def test_copy_analysis_config(analysis_inst): + analysis_inst.copy_analysis_config() + file_path = os.path.join(analysis_dir,"analysis_config.json") + is_file_copied = os.path.exists(file_path) + assert is_file_copied == True + + +def test_load_analysis_config(analysis_inst): + analysis_inst.load_analysis_config() + metrics_test = analysis_inst.metrics[0] + test_dir_read = analysis_inst.results_dirs[0] + assert metrics_test == "mse" + assert test_dir_read == test_dir + + +def test_read_values_by_var_from_nc(analysis_inst): + file_nc = test_nc_fl + real,persistent,forecast,time_forecast = analysis_inst.read_values_by_var_from_nc(fl_nc = file_nc) + assert len(real) == len(persistent) == len(forecast) + assert len(time_forecast) == len(forecast) + + +def test_calculate_metric_one_dir(analysis_inst): + file_nc = test_nc_fl + real,persistent,forecast,time_forecast = analysis_inst.read_values_by_var_from_nc(fl_nc = file_nc) + eval_forecast = analysis_inst.calculate_metric_one_img(real,forecast,metric="mse") + +def test_load_results_dir_parameters(analysis_inst): + analysis_inst.load_results_dir_parameters() + assert len(analysis_inst.compare_by_values) == 2 + +def test_calculate_metric_all_dirs(analysis_inst): + analysis_inst.calculate_metric_all_dirs() + +#def test_calculate_metric_all_dirs(analysis_inst): +# analysis_inst.calculate_metrics_all_dirs() +# assert list(analysis_inst.eval_all.keys())[0] == analysis_inst.results_dirs[0] +# print(analysis_inst.eval_all[test_dir].keys()) +# assert len(analysis_inst.eval_all[test_dir]["persistent"]["mse"][1]) == 10 + + +#def test_calculate_mean_vars_forecast(analysis_inst): +# analysis_inst.calculate_metrics_all_dirs() +# analysis_inst.calculate_mean_vars_forecast() +# assert len(analysis_inst.results_dict[test_dir]["forecasts"]) == 2 + + +def test_plot_results(analysis_inst): + analysis_inst.compare_by_values[0] = "savp" + analysis_inst.plot_results() diff --git a/video_prediction_tools/main_scripts/main_meta_postprocess.py b/video_prediction_tools/main_scripts/main_meta_postprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..044dffe0b1cc3ef94a844158456077f19b88f719 --- /dev/null +++ b/video_prediction_tools/main_scripts/main_meta_postprocess.py @@ -0,0 +1,180 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +__email__ = "b.gong@fz-juelich.de" +__author__ = "Bing Gong" +__date__ = "2020-12-04" + + +import os +from matplotlib.pylab import plt +import json +import numpy as np +import shutil +import glob +from netCDF4 import Dataset +from model_modules.video_prediction.metrics import * +import xarray as xr + + +class MetaPostprocess(object): + def __init__(self, analysis_config=None, analysis_dir=None, stochastic_ind=0, forecast_type="deterministic"): + """ + This class is used for calculating the evaluation metric, analyize the models' results and make comparsion + args: + forecast_type :str, "deterministic" or "stochastic" + """ + + self.analysis_config = analysis_config + self.analysis_dir = analysis_dir + self.stochastic_ind=stochastic_ind + + def __call__(): + self.create_analysis_dir() + self.copy_analysis_config_to_analysis_dir() + self.load_analysis_config() + self.load_results_dir_parameters() + #self.load_prediction_and_persistent_real() + self.calculate_evaluation_metrics() + self.save_metrics_all_dir_to_json() + self.make_comparsion() + + + def create_analysis_dir(self): + if not os.path.exists(self.analysis_dir):os.makedirs(self.analysis_dir) + + def copy_analysis_config(self): + shutil.copy(self.analysis_config, os.path.join(self.analysis_dir,"analysis_config.json")) + self.analysis_config = os.path.join(self.analysis_dir,"analysis_config.json") + + + def load_analysis_config(self): + """ + Get the analysis json configuration file + """ + with open(self.analysis_config) as f: + self.f = json.load(f) + self.metrics = self.f["metric"] + self.results_dirs = self.f["results_dir"] + self.compare_by = self.f["compare_by"] + + @staticmethod + def load_prediction_and_real_from_one_dir(results_dir,var="T2",stochastic_ind=0): + """ + Load the reference and prediction from one results directory + """ + fl_names = glob.glob(os.path.join(results_dir,"*.nc")) + real_all = [] + persistent_all = [] + forecast_all = [] + for fl_nc in fl_names: + real,persistent, forecast,time_forecast = MetaPostprocess.read_values_by_var_from_nc(fl_nc,var,stochastic_ind) + real_all.append(real) + persistent_all.append(persistent) + forecast_all.append(forecast) + return real_all,persistent_all,forecast_all, time_forecast + + + @staticmethod + def read_values_by_var_from_nc(fl_nc,var="T2",stochastic_ind=0): + #if not var in ["T2","MSL","GPH500"]: raise ValueError ("var name is not correct, should be 'T2','MSL',or 'GPH500'") + with Dataset(fl_nc, mode = 'r') as fl: + #load var prediction, real and persistent values + real = fl["/analysis/reference/"].variables[var][:] + persistent = fl["/analysis/persistent/"].variables[var][:] + forecast = fl["/forecasts/"+var+"/stochastic"].variables[str(stochastic_ind)][:] + time_forecast = fl.variables["time_forecast"][:] + return real, persistent, forecast, time_forecast + + + @staticmethod + def calculate_metric_one_img(real,forecast,metric="mse"): + if metric == "mse": + #compare real and persistent + eval_forecast = mse_imgs(real, forecast) + elif metric == "psnr": + eval_forecast = psnr_imgs(real,forecast) + return eval_forecast + + @staticmethod + def reshape_eval_to_one_dim(values): + return np.array(values).flatten() + + def calculate_metric_all_dirs(self,is_persistent=False,metric="mse"): + """ + Return the evaluation metrics for persistent and forecasing model over forecasting timestampls + + return: + eval_forecast: list, the evaluation metric values for persistent with respect to the dimenisons [results_dir,samples,timestampe] + + """ + eval_forecast_all_dirs = [] + for results_dir in self.results_dirs: + real_all, persistent_all, forecast_all, self.time_forecast = MetaPostprocess.load_prediction_and_real_from_one_dir(results_dir,var="T2",stochastic_ind=self.stochastic_ind) + + if is_persistent: forecast_all = persistent_all + eval_forecast_all = [] + #loop for real data + for idx in range(len(real_all)): + eval_forecast_per_sample_over_ts = [] + #loop the forecast time + for time in range(len(self.time_forecast)): + #loop for each sample and each timestamp + eval_forecast = MetaPostprocess.calculate_metric_one_img(real_all[idx][time],forecast_all[idx][time], metric=metric) + eval_forecast_per_sample_over_ts.append(eval_forecast) + + eval_forecast_all.append(list(eval_forecast_per_sample_over_ts)) + eval_forecast_all_dirs.append(eval_forecast_all) + + times = list(range(len(self.time_forecast))) + samples = list(range(len(real_all))) + print("shape of list",np.array(eval_forecast_all_dirs).shape) + evals_forecast = xr.DataArray(eval_forecast_all_dirs, coords=[self.results_dirs, samples , times], dims=["results_dirs", "samples","time_forecast"]) + return evals_forecast + + + def save_metrics_all_dir_to_json(self): + with open("metrics_results.json","w") as f: + json.dump(self.eval_all,f) + + + def load_results_dir_parameters(self,compare_by="model"): + self.compare_by_values = [] + for results_dir in self.results_dirs: + with open(os.path.join(results_dir, "options_checkpoints.json")) as f: + self.options = json.loads(f.read()) + print("self.options:",self.options) + #if self.compare_by == "model": + self.compare_by_values.append(self.options[compare_by]) + + + + + + def plot_results(self,one_persistent=True): + """ + Plot the mean and vars for the user-defined metrics + """ + self.load_results_dir_parameters(compare_by="model") + evals_forecast = self.calculate_metric_all_dirs(is_persistent=False,metric="mse") + t = evals_forecast["time_forecast"] + mean_forecast = evals_forecast.groupby("time_forecast").mean(dim="samples").values + var_forecast = evals_forecast.groupby("time_forecast").var(dim="samples").values + print("mean_foreast",mean_forecast) + x = np.array(self.compare_by_values) + y = np.array(mean_forecast) + e = np.array(var_forecast) + + plt.errorbar(t,y[0],e[0],linestyle="None",marker='^') + plt.show() + plt.savefig(os.path.join(self.analysis_dir,self.metrics[0]+".png")) + plt.close() + + + + + + + + diff --git a/video_prediction_tools/main_scripts/main_visualize_postprocess.py b/video_prediction_tools/main_scripts/main_visualize_postprocess.py index 30dd8592a1fb4ff666b5b89306ae471fd24ca44b..1901fdeff7c63303e66e107a859f1b47b9e88331 100644 --- a/video_prediction_tools/main_scripts/main_visualize_postprocess.py +++ b/video_prediction_tools/main_scripts/main_visualize_postprocess.py @@ -198,8 +198,9 @@ class Postprocess(TrainModel,ERA5Pkl2Tfrecords): self.iterator = tf.data.Iterator.from_string_handle( self.test_handle, self.test_tf_dataset.output_types, self.test_tf_dataset.output_shapes) self.inputs = self.iterator.get_next() - if self.dataset == "era5" and self.model == "savp": - del self.inputs["T_start"] + self.input_ts = self.inputs["T_start"] + #if self.dataset == "era5" and self.model == "savp": + # del self.inputs["T_start"] def check_stochastic_samples_ind_based_on_model(self): @@ -225,16 +226,23 @@ class Postprocess(TrainModel,ERA5Pkl2Tfrecords): """ self.input_results = self.sess.run(self.inputs) self.input_images = self.input_results["images"] + self.t_starts_results = self.input_results["T_start"] + print("t_starts_results:",self.t_starts_results) + self.t_starts = self.t_starts_results #get one seq and the corresponding start time poin - self.t_starts = self.input_results["T_start"] + #self.t_starts = self.input_results["T_start"] + self.input_images_denorm_all = [] for batch_id in range(self.batch_size): self.input_images_ = Postprocess.get_one_seq_from_batch(self.input_images,batch_id) #Renormalized data for inputs ts = Postprocess.generate_seq_timestamps(self.t_starts[batch_id],len_seq=self.sequence_length) - self.input_images_denorm = Postprocess.denorm_images_all_channels(self.stat_fl,self.input_images_,self.vars_in) - assert len(self.input_images_denorm.shape) == 4 - Postprocess.plot_seq_imgs(imgs = self.input_images_denorm[self.context_frames+1:,:,:,0],lats=self.lats,lons=self.lons,ts=ts[self.context_frames+1:],label="Ground Truth",output_png_dir=self.results_dir) - return self.input_results, self.input_images,self.t_starts + input_images_denorm = Postprocess.denorm_images_all_channels(self.stat_fl,self.input_images_,self.vars_in) + assert len(input_images_denorm.shape) == 4 + Postprocess.plot_seq_imgs(imgs = input_images_denorm[self.context_frames+1:,:,:,0],lats=self.lats,lons=self.lons,ts=ts[self.context_frames+1:],label="Ground Truth",output_png_dir=self.results_dir) + + self.input_images_denorm_all.append(list(input_images_denorm)) + + return self.input_results, np.array(self.input_images_denorm_all),self.t_starts def run(self): @@ -249,7 +257,7 @@ class Postprocess(TrainModel,ERA5Pkl2Tfrecords): if self.num_samples_per_epoch < self.sample_ind: break else: - self.input_results, self.input_images, self.t_starts = self.run_and_plot_inputs_per_batch() #run the inputs and plot each sequence images + self.input_results, self.input_images_denorm_all, self.t_starts = self.run_and_plot_inputs_per_batch() #run the inputs and plot each sequence images feed_dict = {input_ph: self.input_results[name] for name, input_ph in self.inputs.items()} gen_images_stochastic = [] #[stochastic_ind,batch_size,seq_len,lat,lon,channels] @@ -287,7 +295,8 @@ class Postprocess(TrainModel,ERA5Pkl2Tfrecords): # save input and stochastic generate images to netcdf file # For each prediction (either deterministic or ensemble) we create one netCDF file. for batch_id in range(self.batch_size): - self.save_to_netcdf_for_stochastic_generate_images(self.input_images[batch_id], persistent_images_per_batch[batch_id], + print("batch_id is here",batch_id) + self.save_to_netcdf_for_stochastic_generate_images(self.input_images_denorm_all[batch_id], persistent_images_per_batch[batch_id], np.array(gen_images_stochastic)[:,batch_id,:,:,:,:], fl_name="vfp_date_{}_sample_ind_{}.nc".format(ts_batch[batch_id],self.sample_ind+batch_id)) @@ -406,6 +415,7 @@ class Postprocess(TrainModel,ERA5Pkl2Tfrecords): gen_images_stochastic: list/array (float), [stochastic_number,seq,lat,lon,channel] fl_name : str, the netcdf file name to be saved """ + print("inputs fpor netcdf:",input_images_) assert (len(np.array(input_images_).shape)==len(np.array(gen_images_stochastic).shape))-1 persistent_images_ = np.array(persistent_images_) assert len(persistent_images_.shape) == 4 #[seq,lat,lon,channel] @@ -679,5 +689,4 @@ def main(): if __name__ == '__main__': - main() - + main()