Commit f3b2ba0e authored by gong1

Bug fix for running the data iterator twice per iteration

parent 4e5ceff2
Pipeline #56252 failed
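The commit message appears to refer to a tf.data pitfall in the TensorFlow 1.x graph API: every sess.run() that fetches tensors derived from iterator.get_next() advances the iterator, so fetching the input batch in one run call and evaluating the model in a second run call consumes two batches per loop iteration. A minimal, standalone sketch of that pitfall and of the fetch-once-then-feed pattern (illustration only, not the project's code):

import tensorflow as tf  # TensorFlow 1.x API

dataset = tf.data.Dataset.range(10).batch(2)
iterator = dataset.make_one_shot_iterator()
next_batch = iterator.get_next()

x_ph = tf.placeholder(tf.int64, shape=[None])
doubled = x_ph * 2

with tf.Session() as sess:
    # Buggy pattern: two sess.run calls that both depend on next_batch advance
    # the iterator twice, so every second batch is silently skipped:
    #   inputs = sess.run(next_batch)
    #   result = sess.run(next_batch * 2)
    # Safe pattern: fetch the batch once as NumPy values and feed them back in.
    inputs = sess.run(next_batch)                        # iterator advances exactly once
    result = sess.run(doubled, feed_dict={x_ph: inputs})
    print(inputs, result)

Below are the test module for MetaPostprocess, the main_scripts.main_meta_postprocess module itself, and the diff of the postprocessing script touched by this commit.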
__email__ = "b.gong@fz-juelich.de"
__author__ = "Bing Gong"
__date__ = "2020-12-04"
from main_scripts.main_meta_postprocess import *
import os
import pytest
#Params
analysis_config = "/p/home/jusers/gong1/juwels/ambs/video_prediction_tools/analysis_config/analysis_test.json"
analysis_dir = "/p/home/jusers/gong1/juwels/video_prediction_shared_folder/analysis/bing_test1"
test_nc_fl = "/p/project/deepacf/deeprain/video_prediction_shared_folder/results/era5-Y2015to2017M01to12-64x64-3930N0000E-T2_MSL_gph500/convLSTM/20201221T181605_gong1_sunny/vfp_date_2017030118_sample_ind_13.nc"
test_dir = "/p/project/deepacf/deeprain/video_prediction_shared_folder/results/era5-Y2015to2017M01to12-64x64-3930N0000E-T2_MSL_gph500/convLSTM/20201221T181605_gong1_sunny"
#setup instance
@pytest.fixture(scope="module")
def analysis_inst():
return MetaPostprocess(analysis_config=analysis_config,analysis_dir=analysis_dir)
def test_create_analysis_dir(analysis_inst):
analysis_inst.create_analysis_dir()
is_path = os.path.isdir(analysis_dir)
assert is_path == True
def test_copy_analysis_config(analysis_inst):
analysis_inst.copy_analysis_config()
file_path = os.path.join(analysis_dir,"analysis_config.json")
is_file_copied = os.path.exists(file_path)
assert is_file_copied == True
def test_load_analysis_config(analysis_inst):
analysis_inst.load_analysis_config()
metrics_test = analysis_inst.metrics[0]
test_dir_read = analysis_inst.results_dirs[0]
assert metrics_test == "mse"
assert test_dir_read == test_dir
def test_read_values_by_var_from_nc(analysis_inst):
file_nc = test_nc_fl
real,persistent,forecast,time_forecast = analysis_inst.read_values_by_var_from_nc(fl_nc = file_nc)
assert len(real) == len(persistent) == len(forecast)
assert len(time_forecast) == len(forecast)
def test_calculate_metric_one_dir(analysis_inst):
file_nc = test_nc_fl
real,persistent,forecast,time_forecast = analysis_inst.read_values_by_var_from_nc(fl_nc = file_nc)
eval_forecast = analysis_inst.calculate_metric_one_img(real,forecast,metric="mse")
def test_load_results_dir_parameters(analysis_inst):
analysis_inst.load_results_dir_parameters()
assert len(analysis_inst.compare_by_values) == 2
def test_calculate_metric_all_dirs(analysis_inst):
analysis_inst.calculate_metric_all_dirs()
#def test_calculate_metric_all_dirs(analysis_inst):
# analysis_inst.calculate_metrics_all_dirs()
# assert list(analysis_inst.eval_all.keys())[0] == analysis_inst.results_dirs[0]
# print(analysis_inst.eval_all[test_dir].keys())
# assert len(analysis_inst.eval_all[test_dir]["persistent"]["mse"][1]) == 10
#def test_calculate_mean_vars_forecast(analysis_inst):
# analysis_inst.calculate_metrics_all_dirs()
# analysis_inst.calculate_mean_vars_forecast()
# assert len(analysis_inst.results_dict[test_dir]["forecasts"]) == 2
def test_plot_results(analysis_inst):
analysis_inst.compare_by_values[0] = "savp"
analysis_inst.plot_results()
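The tests assume an analysis configuration JSON with the keys that MetaPostprocess.load_analysis_config() reads ("metric", "results_dir", "compare_by"), with "mse" as the first metric and two results directories (see test_load_analysis_config and test_load_results_dir_parameters). A minimal sketch of such a file, using placeholder values since the actual analysis_test.json is not shown here:

import json

# Hypothetical analysis_test.json content; the keys follow load_analysis_config(),
# all paths are placeholders, not the actual JUWELS configuration.
example_config = {
    "metric": ["mse"],                          # first metric checked by test_load_analysis_config
    "results_dir": ["/path/to/results_dir_1",   # two entries, matching
                    "/path/to/results_dir_2"],  # test_load_results_dir_parameters
    "compare_by": "model",                      # assumed value; only the key is read here
}

with open("analysis_test.json", "w") as f:
    json.dump(example_config, f, indent=4)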
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
__email__ = "b.gong@fz-juelich.de"
__author__ = "Bing Gong"
__date__ = "2020-12-04"
import os
import matplotlib.pyplot as plt
import json
import numpy as np
import shutil
import glob
from netCDF4 import Dataset
from model_modules.video_prediction.metrics import *
import xarray as xr
class MetaPostprocess(object):

    def __init__(self, analysis_config=None, analysis_dir=None, stochastic_ind=0, forecast_type="deterministic"):
        """
        This class is used to calculate the evaluation metrics, analyse the models' results and compare them.
        args:
            forecast_type: str, "deterministic" or "stochastic"
        """
        self.analysis_config = analysis_config
        self.analysis_dir = analysis_dir
        self.stochastic_ind = stochastic_ind

    def __call__(self):
        self.create_analysis_dir()
        self.copy_analysis_config()
        self.load_analysis_config()
        self.load_results_dir_parameters()
        #self.load_prediction_and_persistent_real()
        self.calculate_metric_all_dirs()
        self.save_metrics_all_dir_to_json()
        self.plot_results()
    def create_analysis_dir(self):
        if not os.path.exists(self.analysis_dir):
            os.makedirs(self.analysis_dir)

    def copy_analysis_config(self):
        shutil.copy(self.analysis_config, os.path.join(self.analysis_dir, "analysis_config.json"))
        self.analysis_config = os.path.join(self.analysis_dir, "analysis_config.json")

    def load_analysis_config(self):
        """
        Get the analysis json configuration file
        """
        with open(self.analysis_config) as f:
            self.f = json.load(f)
        self.metrics = self.f["metric"]
        self.results_dirs = self.f["results_dir"]
        self.compare_by = self.f["compare_by"]
    @staticmethod
    def load_prediction_and_real_from_one_dir(results_dir, var="T2", stochastic_ind=0):
        """
        Load the reference and prediction from one results directory
        """
        fl_names = glob.glob(os.path.join(results_dir, "*.nc"))
        real_all = []
        persistent_all = []
        forecast_all = []
        for fl_nc in fl_names:
            real, persistent, forecast, time_forecast = MetaPostprocess.read_values_by_var_from_nc(fl_nc, var, stochastic_ind)
            real_all.append(real)
            persistent_all.append(persistent)
            forecast_all.append(forecast)
        return real_all, persistent_all, forecast_all, time_forecast
    @staticmethod
    def read_values_by_var_from_nc(fl_nc, var="T2", stochastic_ind=0):
        #if not var in ["T2","MSL","GPH500"]: raise ValueError ("var name is not correct, should be 'T2','MSL',or 'GPH500'")
        with Dataset(fl_nc, mode='r') as fl:
            #load var prediction, real and persistent values
            real = fl["/analysis/reference/"].variables[var][:]
            persistent = fl["/analysis/persistent/"].variables[var][:]
            forecast = fl["/forecasts/" + var + "/stochastic"].variables[str(stochastic_ind)][:]
            time_forecast = fl.variables["time_forecast"][:]
        return real, persistent, forecast, time_forecast
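    # Layout of each results netCDF file as inferred from the reads above (an assumption based on
    # this reader, not a documented schema):
    #   /analysis/reference/<var>         ground-truth values for <var> (e.g. T2, MSL, GPH500)
    #   /analysis/persistent/<var>        persistence-forecast values for <var>
    #   /forecasts/<var>/stochastic/<i>   i-th stochastic forecast member (index 0 for deterministic runs)
    #   time_forecast                     forecast timestamps stored at the file root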
    @staticmethod
    def calculate_metric_one_img(real, forecast, metric="mse"):
        if metric == "mse":
            #compare real and forecast values
            eval_forecast = mse_imgs(real, forecast)
        elif metric == "psnr":
            eval_forecast = psnr_imgs(real, forecast)
        else:
            raise ValueError("metric '{}' is not supported, use 'mse' or 'psnr'".format(metric))
        return eval_forecast

    @staticmethod
    def reshape_eval_to_one_dim(values):
        return np.array(values).flatten()
    def calculate_metric_all_dirs(self, is_persistent=False, metric="mse"):
        """
        Return the evaluation metrics for the persistence and forecasting model over the forecast timestamps
        return:
            evals_forecast: xarray DataArray of metric values with dimensions [results_dirs, samples, time_forecast]
        """
        eval_forecast_all_dirs = []
        for results_dir in self.results_dirs:
            real_all, persistent_all, forecast_all, self.time_forecast = MetaPostprocess.load_prediction_and_real_from_one_dir(results_dir, var="T2", stochastic_ind=self.stochastic_ind)
            if is_persistent:
                forecast_all = persistent_all
            eval_forecast_all = []
            #loop over the samples
            for idx in range(len(real_all)):
                eval_forecast_per_sample_over_ts = []
                #loop over the forecast timestamps
                for time in range(len(self.time_forecast)):
                    #evaluate each sample at each timestamp
                    eval_forecast = MetaPostprocess.calculate_metric_one_img(real_all[idx][time], forecast_all[idx][time], metric=metric)
                    eval_forecast_per_sample_over_ts.append(eval_forecast)
                eval_forecast_all.append(list(eval_forecast_per_sample_over_ts))
            eval_forecast_all_dirs.append(eval_forecast_all)

        times = list(range(len(self.time_forecast)))
        samples = list(range(len(real_all)))
        print("shape of list", np.array(eval_forecast_all_dirs).shape)
        evals_forecast = xr.DataArray(eval_forecast_all_dirs, coords=[self.results_dirs, samples, times],
                                      dims=["results_dirs", "samples", "time_forecast"])
        return evals_forecast
    def save_metrics_all_dir_to_json(self):
        with open("metrics_results.json", "w") as f:
            json.dump(self.eval_all, f)

    def load_results_dir_parameters(self, compare_by="model"):
        self.compare_by_values = []
        for results_dir in self.results_dirs:
            with open(os.path.join(results_dir, "options_checkpoints.json")) as f:
                self.options = json.loads(f.read())
                print("self.options:", self.options)
                #if self.compare_by == "model":
                self.compare_by_values.append(self.options[compare_by])
    def plot_results(self, one_persistent=True):
        """
        Plot the mean and variance of the user-defined metrics over the forecast times
        """
        self.load_results_dir_parameters(compare_by="model")
        evals_forecast = self.calculate_metric_all_dirs(is_persistent=False, metric="mse")
        t = evals_forecast["time_forecast"]
        mean_forecast = evals_forecast.groupby("time_forecast").mean(dim="samples").values
        var_forecast = evals_forecast.groupby("time_forecast").var(dim="samples").values
        print("mean_forecast", mean_forecast)
        x = np.array(self.compare_by_values)
        y = np.array(mean_forecast)
        e = np.array(var_forecast)
        plt.errorbar(t, y[0], e[0], linestyle="None", marker='^')
        plt.savefig(os.path.join(self.analysis_dir, self.metrics[0] + ".png"))
        plt.show()
        plt.close()
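For orientation, a short driver sketch (hypothetical paths, not part of the repository) showing how the class above is exercised end to end, in the same order as __call__ and the tests:

from main_scripts.main_meta_postprocess import MetaPostprocess

analysis = MetaPostprocess(analysis_config="/path/to/analysis_test.json",
                           analysis_dir="/path/to/analysis_output")
analysis.create_analysis_dir()           # create the output directory if it does not exist
analysis.copy_analysis_config()          # copy the config next to the analysis results
analysis.load_analysis_config()          # read "metric", "results_dir" and "compare_by"
analysis.load_results_dir_parameters()   # collect the "model" entry of every results_dir
evals = analysis.calculate_metric_all_dirs(metric="mse")   # DataArray [results_dirs, samples, time_forecast]
analysis.plot_results()                  # errorbar plot of the metric mean/variance over forecast time

The diff that follows contains this commit's changes to the Postprocess class of the postprocessing script.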
@@ -198,8 +198,9 @@ class Postprocess(TrainModel,ERA5Pkl2Tfrecords):
         self.iterator = tf.data.Iterator.from_string_handle(
             self.test_handle, self.test_tf_dataset.output_types, self.test_tf_dataset.output_shapes)
         self.inputs = self.iterator.get_next()
-        if self.dataset == "era5" and self.model == "savp":
-            del self.inputs["T_start"]
+        self.input_ts = self.inputs["T_start"]
+        #if self.dataset == "era5" and self.model == "savp":
+        #    del self.inputs["T_start"]
 
     def check_stochastic_samples_ind_based_on_model(self):
@@ -225,16 +226,23 @@ class Postprocess(TrainModel,ERA5Pkl2Tfrecords):
         """
         self.input_results = self.sess.run(self.inputs)
         self.input_images = self.input_results["images"]
+        self.t_starts_results = self.input_results["T_start"]
+        print("t_starts_results:",self.t_starts_results)
+        self.t_starts = self.t_starts_results
         #get one seq and the corresponding start time poin
-        self.t_starts = self.input_results["T_start"]
+        #self.t_starts = self.input_results["T_start"]
+        self.input_images_denorm_all = []
         for batch_id in range(self.batch_size):
             self.input_images_ = Postprocess.get_one_seq_from_batch(self.input_images,batch_id)
             #Renormalized data for inputs
             ts = Postprocess.generate_seq_timestamps(self.t_starts[batch_id],len_seq=self.sequence_length)
-            self.input_images_denorm = Postprocess.denorm_images_all_channels(self.stat_fl,self.input_images_,self.vars_in)
-            assert len(self.input_images_denorm.shape) == 4
-            Postprocess.plot_seq_imgs(imgs = self.input_images_denorm[self.context_frames+1:,:,:,0],lats=self.lats,lons=self.lons,ts=ts[self.context_frames+1:],label="Ground Truth",output_png_dir=self.results_dir)
-        return self.input_results, self.input_images,self.t_starts
+            input_images_denorm = Postprocess.denorm_images_all_channels(self.stat_fl,self.input_images_,self.vars_in)
+            assert len(input_images_denorm.shape) == 4
+            Postprocess.plot_seq_imgs(imgs = input_images_denorm[self.context_frames+1:,:,:,0],lats=self.lats,lons=self.lons,ts=ts[self.context_frames+1:],label="Ground Truth",output_png_dir=self.results_dir)
+            self.input_images_denorm_all.append(list(input_images_denorm))
+        return self.input_results, np.array(self.input_images_denorm_all),self.t_starts
 
     def run(self):
@@ -249,7 +257,7 @@ class Postprocess(TrainModel,ERA5Pkl2Tfrecords):
             if self.num_samples_per_epoch < self.sample_ind:
                 break
             else:
-                self.input_results, self.input_images, self.t_starts = self.run_and_plot_inputs_per_batch() #run the inputs and plot each sequence images
+                self.input_results, self.input_images_denorm_all, self.t_starts = self.run_and_plot_inputs_per_batch() #run the inputs and plot each sequence images
 
             feed_dict = {input_ph: self.input_results[name] for name, input_ph in self.inputs.items()}
             gen_images_stochastic = [] #[stochastic_ind,batch_size,seq_len,lat,lon,channels]
@@ -287,7 +295,8 @@ class Postprocess(TrainModel,ERA5Pkl2Tfrecords):
             # save input and stochastic generate images to netcdf file
             # For each prediction (either deterministic or ensemble) we create one netCDF file.
             for batch_id in range(self.batch_size):
-                self.save_to_netcdf_for_stochastic_generate_images(self.input_images[batch_id], persistent_images_per_batch[batch_id],
+                print("batch_id is here",batch_id)
+                self.save_to_netcdf_for_stochastic_generate_images(self.input_images_denorm_all[batch_id], persistent_images_per_batch[batch_id],
                                                                    np.array(gen_images_stochastic)[:,batch_id,:,:,:,:],
                                                                    fl_name="vfp_date_{}_sample_ind_{}.nc".format(ts_batch[batch_id],self.sample_ind+batch_id))
@@ -406,6 +415,7 @@ class Postprocess(TrainModel,ERA5Pkl2Tfrecords):
            gen_images_stochastic: list/array (float), [stochastic_number,seq,lat,lon,channel]
            fl_name : str, the netcdf file name to be saved
         """
+        print("inputs fpor netcdf:",input_images_)
         assert (len(np.array(input_images_).shape)==len(np.array(gen_images_stochastic).shape))-1
         persistent_images_ = np.array(persistent_images_)
         assert len(persistent_images_.shape) == 4 #[seq,lat,lon,channel]
@@ -680,4 +690,3 @@ def main():
 
 if __name__ == '__main__':
     main()