Skip to content
Snippets Groups Projects
Commit 40a570fe authored by Bing Gong's avatar Bing Gong
Browse files

update scripts

parent ac96c837
Branches
Tags
No related merge requests found
...@@ -33,7 +33,7 @@ from matplotlib.colors import LinearSegmentedColormap ...@@ -33,7 +33,7 @@ from matplotlib.colors import LinearSegmentedColormap
from skimage.metrics import structural_similarity as ssim from skimage.metrics import structural_similarity as ssim
import pickle import pickle
with open("./splits_size_64_64_1/geo_info.json","r") as json_file: with open("geo_info.json","r") as json_file:
geo = json.load(json_file) geo = json.load(json_file)
lat = [round(i,2) for i in geo["lat"]] lat = [round(i,2) for i in geo["lat"]]
lon = [round(i,2) for i in geo["lon"]] lon = [round(i,2) for i in geo["lon"]]
...@@ -196,77 +196,77 @@ def main(): ...@@ -196,77 +196,77 @@ def main():
gen_images_all = [] gen_images_all = []
input_images_all = [] input_images_all = []
# while True: while True:
# print("Sample id", sample_ind) print("Sample id", sample_ind)
# gen_images_stochastic = [] gen_images_stochastic = []
# if args.num_samples and sample_ind >= args.num_samples: if args.num_samples and sample_ind >= args.num_samples:
# break break
# try: try:
# input_results = sess.run(inputs) input_results = sess.run(inputs)
# input_images = input_results["images"] input_images = input_results["images"]
# input_images_all.extend(input_images) input_images_all.extend(input_images)
# with open(os.path.join(args.output_png_dir, "input_images_all"), "wb") as input_files: with open(os.path.join(args.output_png_dir, "input_images_all"), "wb") as input_files:
# pickle.dump(list(input_images_all), input_files) pickle.dump(list(input_images_all), input_files)
#
# except tf.errors.OutOfRangeError: except tf.errors.OutOfRangeError:
# break break
#
# feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()} feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()}
# for stochastic_sample_ind in range(args.num_stochastic_samples): for stochastic_sample_ind in range(args.num_stochastic_samples):
# gen_images = sess.run(model.outputs['gen_images'], feed_dict = feed_dict) gen_images = sess.run(model.outputs['gen_images'], feed_dict = feed_dict)
# gen_images_stochastic.append(gen_images) gen_images_stochastic.append(gen_images)
# print("Stochastic_sample,", stochastic_sample_ind) print("Stochastic_sample,", stochastic_sample_ind)
# for i in range(args.batch_size): for i in range(args.batch_size):
# print("batch", i) print("batch", i)
# #colors = [(1, 0, 0), (0, 1, 0), (0, 0, 1)] #colors = [(1, 0, 0), (0, 1, 0), (0, 0, 1)]
# cmap_name = 'my_list' cmap_name = 'my_list'
# if sample_ind < 20 and i == 1: if sample_ind < 20 and i == 1:
# name = 'Stochastic_id_' + str(stochastic_sample_ind) + 'Batch_id_' + str( name = 'Stochastic_id_' + str(stochastic_sample_ind) + 'Batch_id_' + str(
# sample_ind) + " + Sample_" + str(i) sample_ind) + " + Sample_" + str(i)
# gen_images_ = np.array(list(input_images[i,:context_frames]) + list(gen_images[i,-future_length:, :])) gen_images_ = np.array(list(input_images[i,:context_frames]) + list(gen_images[i,-future_length:, :]))
# #gen_images_ = gen_images[i, :] #gen_images_ = gen_images[i, :]
# input_images_ = input_images[i, :] input_images_ = input_images[i, :]
# input_gen_diff = (input_images_[:, :, :,0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922) - (gen_images_[:, :, :, 0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922) input_gen_diff = (input_images_[:, :, :,0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922) - (gen_images_[:, :, :, 0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922)
#
# gen_mse_avg_ = [np.mean(input_gen_diff[frame, :, :] ** 2) for frame in gen_mse_avg_ = [np.mean(input_gen_diff[frame, :, :] ** 2) for frame in
# range(sequence_length)] # return the list with 10 (sequence) mse range(sequence_length)] # return the list with 10 (sequence) mse
fig = plt.figure(figsize=(18,6))
gs = gridspec.GridSpec(1, 10)
gs.update(wspace = 0., hspace = 0.)
ts = [0,5,9,10,12,14,16,18,19]
xlables = [round(i,2) for i in list(np.linspace(np.min(lon),np.max(lon),5))]
ylabels = [round(i,2) for i in list(np.linspace(np.max(lat),np.min(lat),5))]
for t in range(len(ts)):
#if t==0 : ax1=plt.subplot(gs[t])
ax1 = plt.subplot(gs[t])
input_image = input_images_[ts[t], :, :, 0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922
plt.imshow(input_image, cmap = 'jet', vmin=270, vmax=300)
ax1.title.set_text("t = " + str(ts[t]+1))
plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = [])
if t == 0:
plt.setp([ax1], xticks = list(np.linspace(0, 64, 3)), xticklabels = xlables, yticks = list(np.linspace(0, 64, 3)), yticklabels = ylabels)
plt.ylabel("Ground Truth", fontsize=10)
plt.savefig(os.path.join(args.output_png_dir, "Ground_Truth_Sample_" + str(name) + ".jpg"))
plt.clf()
fig = plt.figure(figsize=(12,6))
gs = gridspec.GridSpec(1, 10)
gs.update(wspace = 0., hspace = 0.)
ts = [10,12,14,16,18,19]
for t in range(len(ts)):
#if t==0 : ax1=plt.subplot(gs[t])
ax1 = plt.subplot(gs[t])
gen_image = gen_images_[ts[t], :, :, 0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922
plt.imshow(gen_image, cmap = 'jet', vmin=270, vmax=300)
ax1.title.set_text("t = " + str(ts[t]+1))
plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = [])
plt.savefig(os.path.join(args.output_png_dir, "Predicted_Sample_" + str(name) + ".jpg"))
plt.clf()
# fig = plt.figure(figsize=(18,6))
# gs = gridspec.GridSpec(1, 10)
# gs.update(wspace = 0., hspace = 0.)
# ts = [0,5,9,10,12,14,16,18,19]
# xlables = [round(i,2) for i in list(np.linspace(np.min(lon),np.max(lon),5))]
# ylabels = [round(i,2) for i in list(np.linspace(np.max(lat),np.min(lat),5))]
#
# for t in range(len(ts)):
# #if t==0 : ax1=plt.subplot(gs[t])
# ax1 = plt.subplot(gs[t])
# input_image = input_images_[ts[t], :, :, 0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922
# plt.imshow(input_image, cmap = 'jet', vmin=270, vmax=300)
# ax1.title.set_text("t = " + str(ts[t]+1))
# plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = [])
#
# if t == 0:
# plt.setp([ax1], xticks = list(np.linspace(0, 64, 3)), xticklabels = xlables, yticks = list(np.linspace(0, 64, 3)), yticklabels = ylabels)
# plt.ylabel("Ground Truth", fontsize=10)
# plt.savefig(os.path.join(args.output_png_dir, "Ground_Truth_Sample_" + str(name) + ".jpg"))
# plt.clf()
#
# fig = plt.figure(figsize=(12,6))
# gs = gridspec.GridSpec(1, 10)
# gs.update(wspace = 0., hspace = 0.)
# ts = [10,12,14,16,18,19]
# for t in range(len(ts)):
# #if t==0 : ax1=plt.subplot(gs[t])
# ax1 = plt.subplot(gs[t])
# gen_image = gen_images_[ts[t], :, :, 0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922
# plt.imshow(gen_image, cmap = 'jet', vmin=270, vmax=300)
# ax1.title.set_text("t = " + str(ts[t]+1))
# plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = [])
#
# plt.savefig(os.path.join(args.output_png_dir, "Predicted_Sample_" + str(name) + ".jpg"))
# plt.clf()
#
# fig = plt.figure() # fig = plt.figure()
# gs = gridspec.GridSpec(4,6) # gs = gridspec.GridSpec(4,6)
# gs.update(wspace = 0.7,hspace=0.8) # gs.update(wspace = 0.7,hspace=0.8)
...@@ -356,30 +356,30 @@ def main(): ...@@ -356,30 +356,30 @@ def main():
# ani = animation.FuncAnimation(fig, animation_sample, frames=len(gen_mse_avg_), interval = 1000, # ani = animation.FuncAnimation(fig, animation_sample, frames=len(gen_mse_avg_), interval = 1000,
# repeat_delay=2000) # repeat_delay=2000)
# ani.save(os.path.join(args.output_png_dir, "Sample_" + str(name) + ".mp4")) # ani.save(os.path.join(args.output_png_dir, "Sample_" + str(name) + ".mp4"))
#
# else:
# pass
else:
pass
#
# if sample_ind == 0:
# gen_images_all = gen_images_stochastic if sample_ind == 0:
# else: gen_images_all = gen_images_stochastic
# gen_images_all = np.concatenate((np.array(gen_images_all), np.array(gen_images_stochastic)), axis=1) else:
# gen_images_all = np.concatenate((np.array(gen_images_all), np.array(gen_images_stochastic)), axis=1)
# if args.num_stochastic_samples == 1:
# with open(os.path.join(args.output_png_dir, "gen_images_all"), "wb") as gen_files: if args.num_stochastic_samples == 1:
# pickle.dump(list(gen_images_all[0]), gen_files) with open(os.path.join(args.output_png_dir, "gen_images_all"), "wb") as gen_files:
# else: pickle.dump(list(gen_images_all[0]), gen_files)
# with open(os.path.join(args.output_png_dir, "gen_images_sample_id_" + str(sample_ind)),"wb") as gen_files: else:
# pickle.dump(list(gen_images_stochastic), gen_files) with open(os.path.join(args.output_png_dir, "gen_images_sample_id_" + str(sample_ind)),"wb") as gen_files:
# with open(os.path.join(args.output_png_dir, "gen_images_all_stochastic"), "wb") as gen_files: pickle.dump(list(gen_images_stochastic), gen_files)
# pickle.dump(list(gen_images_all), gen_files) with open(os.path.join(args.output_png_dir, "gen_images_all_stochastic"), "wb") as gen_files:
# pickle.dump(list(gen_images_all), gen_files)
#
#
#
# sample_ind += args.batch_size
sample_ind += args.batch_size
# # for i, gen_mse_avg_ in enumerate(gen_mse_avg): # # for i, gen_mse_avg_ in enumerate(gen_mse_avg):
...@@ -390,8 +390,7 @@ def main(): ...@@ -390,8 +390,7 @@ def main():
# # plt.xlabel("Frames") # # plt.xlabel("Frames")
# # plt.ylabel("MSE_AVG") # # plt.ylabel("MSE_AVG")
# # #X = list(range(len(gen_mse_avg_))) # # #X = list(range(len(gen_mse_avg_)))
# # #for t, gen_mse_avg_ in enume # # #for t, gen_mse_avg_ in enumerate(gen_mse_avg):
# rate(gen_mse_avg):
# # def animate_metric(j): # # def animate_metric(j):
# # data = gen_mse_avg_[:(j+1)] # # data = gen_mse_avg_[:(j+1)]
# # x = list(range(len(gen_mse_avg_)))[:(j+1)] # # x = list(range(len(gen_mse_avg_)))[:(j+1)]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment