diff --git a/scripts/generate_transfer_learning_finetune.py b/scripts/generate_transfer_learning_finetune.py index 80470df25b5db6854c15d343bf2e62c596b40cb5..c990ba574aba558dbd0735f85f02f82f2118ca08 100644 --- a/scripts/generate_transfer_learning_finetune.py +++ b/scripts/generate_transfer_learning_finetune.py @@ -33,7 +33,7 @@ from matplotlib.colors import LinearSegmentedColormap from skimage.metrics import structural_similarity as ssim import pickle -with open("./splits_size_64_64_1/geo_info.json","r") as json_file: +with open("geo_info.json","r") as json_file: geo = json.load(json_file) lat = [round(i,2) for i in geo["lat"]] lon = [round(i,2) for i in geo["lon"]] @@ -196,190 +196,190 @@ def main(): gen_images_all = [] input_images_all = [] - # while True: - # print("Sample id", sample_ind) - # gen_images_stochastic = [] - # if args.num_samples and sample_ind >= args.num_samples: - # break - # try: - # input_results = sess.run(inputs) - # input_images = input_results["images"] - # input_images_all.extend(input_images) - # with open(os.path.join(args.output_png_dir, "input_images_all"), "wb") as input_files: - # pickle.dump(list(input_images_all), input_files) - # - # except tf.errors.OutOfRangeError: - # break - # - # feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()} - # for stochastic_sample_ind in range(args.num_stochastic_samples): - # gen_images = sess.run(model.outputs['gen_images'], feed_dict = feed_dict) - # gen_images_stochastic.append(gen_images) - # print("Stochastic_sample,", stochastic_sample_ind) - # for i in range(args.batch_size): - # print("batch", i) - # #colors = [(1, 0, 0), (0, 1, 0), (0, 0, 1)] - # cmap_name = 'my_list' - # if sample_ind < 20 and i == 1: - # name = 'Stochastic_id_' + str(stochastic_sample_ind) + 'Batch_id_' + str( - # sample_ind) + " + Sample_" + str(i) - # gen_images_ = np.array(list(input_images[i,:context_frames]) + list(gen_images[i,-future_length:, :])) - # #gen_images_ = gen_images[i, :] - # input_images_ = input_images[i, :] - # input_gen_diff = (input_images_[:, :, :,0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922) - (gen_images_[:, :, :, 0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922) - # - # gen_mse_avg_ = [np.mean(input_gen_diff[frame, :, :] ** 2) for frame in - # range(sequence_length)] # return the list with 10 (sequence) mse - - # fig = plt.figure(figsize=(18,6)) - # gs = gridspec.GridSpec(1, 10) - # gs.update(wspace = 0., hspace = 0.) - # ts = [0,5,9,10,12,14,16,18,19] - # xlables = [round(i,2) for i in list(np.linspace(np.min(lon),np.max(lon),5))] - # ylabels = [round(i,2) for i in list(np.linspace(np.max(lat),np.min(lat),5))] - # - # for t in range(len(ts)): - # #if t==0 : ax1=plt.subplot(gs[t]) - # ax1 = plt.subplot(gs[t]) - # input_image = input_images_[ts[t], :, :, 0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922 - # plt.imshow(input_image, cmap = 'jet', vmin=270, vmax=300) - # ax1.title.set_text("t = " + str(ts[t]+1)) - # plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = []) - # - # if t == 0: - # plt.setp([ax1], xticks = list(np.linspace(0, 64, 3)), xticklabels = xlables, yticks = list(np.linspace(0, 64, 3)), yticklabels = ylabels) - # plt.ylabel("Ground Truth", fontsize=10) - # plt.savefig(os.path.join(args.output_png_dir, "Ground_Truth_Sample_" + str(name) + ".jpg")) - # plt.clf() - # - # fig = plt.figure(figsize=(12,6)) - # gs = gridspec.GridSpec(1, 10) - # gs.update(wspace = 0., hspace = 0.) - # ts = [10,12,14,16,18,19] - # for t in range(len(ts)): - # #if t==0 : ax1=plt.subplot(gs[t]) - # ax1 = plt.subplot(gs[t]) - # gen_image = gen_images_[ts[t], :, :, 0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922 - # plt.imshow(gen_image, cmap = 'jet', vmin=270, vmax=300) - # ax1.title.set_text("t = " + str(ts[t]+1)) - # plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = []) - # - # plt.savefig(os.path.join(args.output_png_dir, "Predicted_Sample_" + str(name) + ".jpg")) - # plt.clf() - # - # fig = plt.figure() - # gs = gridspec.GridSpec(4,6) - # gs.update(wspace = 0.7,hspace=0.8) - # ax1 = plt.subplot(gs[0:2,0:3]) - # ax2 = plt.subplot(gs[0:2,3:],sharey=ax1) - # ax3 = plt.subplot(gs[2:4,0:3]) - # ax4 = plt.subplot(gs[2:4,3:]) - # xlables = [round(i,2) for i in list(np.linspace(np.min(lon),np.max(lon),5))] - # ylabels = [round(i,2) for i in list(np.linspace(np.max(lat),np.min(lat),5))] - # plt.setp([ax1,ax2,ax3],xticks=list(np.linspace(0,64,5)), xticklabels=xlables ,yticks=list(np.linspace(0,64,5)),yticklabels=ylabels) - # ax1.title.set_text("(a) Ground Truth") - # ax2.title.set_text("(b) SAVP") - # ax3.title.set_text("(c) Diff.") - # ax4.title.set_text("(d) MSE") - # - # ax1.xaxis.set_tick_params(labelsize=7) - # ax1.yaxis.set_tick_params(labelsize = 7) - # ax2.xaxis.set_tick_params(labelsize=7) - # ax2.yaxis.set_tick_params(labelsize = 7) - # ax3.xaxis.set_tick_params(labelsize=7) - # ax3.yaxis.set_tick_params(labelsize = 7) - # - # init_images = np.zeros((input_images_.shape[1], input_images_.shape[2])) - # print("inti images shape", init_images.shape) - # xdata, ydata = [], [] - # #plot1 = ax1.imshow(init_images, cmap='jet', vmin =0, vmax = 1) - # #plot2 = ax2.imshow(init_images, cmap='jet', vmin =0, vmax = 1) - # plot1 = ax1.imshow(init_images, cmap='jet', vmin = 270, vmax = 300) - # plot2 = ax2.imshow(init_images, cmap='jet', vmin = 270, vmax = 300) - # #x = np.linspace(0, 64, 64) - # #y = np.linspace(0, 64, 64) - # #plot1 = ax1.contourf(x,y,init_images, cmap='jet', vmin = np.min(input_images), vmax = np.max(input_images)) - # #plot2 = ax2.contourf(x,y,init_images, cmap='jet', vmin = np.min(input_images), vmax = np.max(input_images)) - # fig.colorbar(plot1, ax=ax1).ax.tick_params(labelsize=7) - # fig.colorbar(plot2, ax=ax2).ax.tick_params(labelsize=7) - # - # cm = LinearSegmentedColormap.from_list( - # cmap_name, "bwr", N = 5) - # - # plot3 = ax3.imshow(init_images, vmin=-20, vmax=20, cmap=cm)#cmap = 'PuBu_r', - # #plot3 = ax3.imshow(init_images, vmin = -1, vmax = 1, cmap = cm) # cmap = 'PuBu_r', - # plot4, = ax4.plot([], [], color = "r") - # ax4.set_xlim(0, future_length-1) - # ax4.set_ylim(0, 20) - # #ax4.set_ylim(0, 0.5) - # ax4.set_xlabel("Frames", fontsize=10) - # #ax4.set_ylabel("MSE", fontsize=10) - # ax4.xaxis.set_tick_params(labelsize=7) - # ax4.yaxis.set_tick_params(labelsize=7) - # - # - # plots = [plot1, plot2, plot3, plot4] - # - # #fig.colorbar(plots[1], ax = [ax1, ax2]) - # - # fig.colorbar(plots[2], ax=ax3).ax.tick_params(labelsize=7) - # #fig.colorbar(plot1[0], ax=ax1).ax.tick_params(labelsize=7) - # #fig.colorbar(plot2[1], ax=ax2).ax.tick_params(labelsize=7) - # - # def animation_sample(t): - # input_image = input_images_[t, :, :, 0]* (321.46630859375-235.2141571044922) + 235.2141571044922 - # gen_image = gen_images_[t, :, :, 0]* (321.46630859375-235.2141571044922) + 235.2141571044922 - # diff_image = input_gen_diff[t,:,:] - # # p = sns.lineplot(x=x,y=data,color="b") - # # p.tick_params(labelsize=17) - # # plt.setp(p.lines, linewidth=6) - # plots[0].set_data(input_image) - # plots[1].set_data(gen_image) - # #plots[0] = ax1.contourf(x, y, input_image, cmap = 'jet', vmin = np.min(input_images),vmax = np.max(input_images)) - # #plots[1] = ax2.contourf(x, y, gen_image, cmap = 'jet', vmin = np.min(input_images),vmax = np.max(input_images)) - # plots[2].set_data(diff_image) - # - # if t >= future_length: - # #data = gen_mse_avg_[:t + 1] - # # x = list(range(len(gen_mse_avg_)))[:t+1] - # xdata.append(t-future_length) - # print("xdata", xdata) - # ydata.append(gen_mse_avg_[t]) - # print("ydata", ydata) - # plots[3].set_data(xdata, ydata) - # fig.suptitle("Predicted Frame " + str(t-future_length)) - # else: - # #plots[3].set_data(xdata, ydata) - # fig.suptitle("Context Frame " + str(t)) - # return plots - # - # ani = animation.FuncAnimation(fig, animation_sample, frames=len(gen_mse_avg_), interval = 1000, - # repeat_delay=2000) - # ani.save(os.path.join(args.output_png_dir, "Sample_" + str(name) + ".mp4")) - # - # else: - # pass - - - # - # if sample_ind == 0: - # gen_images_all = gen_images_stochastic - # else: - # gen_images_all = np.concatenate((np.array(gen_images_all), np.array(gen_images_stochastic)), axis=1) - # - # if args.num_stochastic_samples == 1: - # with open(os.path.join(args.output_png_dir, "gen_images_all"), "wb") as gen_files: - # pickle.dump(list(gen_images_all[0]), gen_files) - # else: - # with open(os.path.join(args.output_png_dir, "gen_images_sample_id_" + str(sample_ind)),"wb") as gen_files: - # pickle.dump(list(gen_images_stochastic), gen_files) - # with open(os.path.join(args.output_png_dir, "gen_images_all_stochastic"), "wb") as gen_files: - # pickle.dump(list(gen_images_all), gen_files) - # - # - # - # - # sample_ind += args.batch_size + while True: + print("Sample id", sample_ind) + gen_images_stochastic = [] + if args.num_samples and sample_ind >= args.num_samples: + break + try: + input_results = sess.run(inputs) + input_images = input_results["images"] + input_images_all.extend(input_images) + with open(os.path.join(args.output_png_dir, "input_images_all"), "wb") as input_files: + pickle.dump(list(input_images_all), input_files) + + except tf.errors.OutOfRangeError: + break + + feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()} + for stochastic_sample_ind in range(args.num_stochastic_samples): + gen_images = sess.run(model.outputs['gen_images'], feed_dict = feed_dict) + gen_images_stochastic.append(gen_images) + print("Stochastic_sample,", stochastic_sample_ind) + for i in range(args.batch_size): + print("batch", i) + #colors = [(1, 0, 0), (0, 1, 0), (0, 0, 1)] + cmap_name = 'my_list' + if sample_ind < 20 and i == 1: + name = 'Stochastic_id_' + str(stochastic_sample_ind) + 'Batch_id_' + str( + sample_ind) + " + Sample_" + str(i) + gen_images_ = np.array(list(input_images[i,:context_frames]) + list(gen_images[i,-future_length:, :])) + #gen_images_ = gen_images[i, :] + input_images_ = input_images[i, :] + input_gen_diff = (input_images_[:, :, :,0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922) - (gen_images_[:, :, :, 0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922) + + gen_mse_avg_ = [np.mean(input_gen_diff[frame, :, :] ** 2) for frame in + range(sequence_length)] # return the list with 10 (sequence) mse + + fig = plt.figure(figsize=(18,6)) + gs = gridspec.GridSpec(1, 10) + gs.update(wspace = 0., hspace = 0.) + ts = [0,5,9,10,12,14,16,18,19] + xlables = [round(i,2) for i in list(np.linspace(np.min(lon),np.max(lon),5))] + ylabels = [round(i,2) for i in list(np.linspace(np.max(lat),np.min(lat),5))] + + for t in range(len(ts)): + #if t==0 : ax1=plt.subplot(gs[t]) + ax1 = plt.subplot(gs[t]) + input_image = input_images_[ts[t], :, :, 0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922 + plt.imshow(input_image, cmap = 'jet', vmin=270, vmax=300) + ax1.title.set_text("t = " + str(ts[t]+1)) + plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = []) + + if t == 0: + plt.setp([ax1], xticks = list(np.linspace(0, 64, 3)), xticklabels = xlables, yticks = list(np.linspace(0, 64, 3)), yticklabels = ylabels) + plt.ylabel("Ground Truth", fontsize=10) + plt.savefig(os.path.join(args.output_png_dir, "Ground_Truth_Sample_" + str(name) + ".jpg")) + plt.clf() + + fig = plt.figure(figsize=(12,6)) + gs = gridspec.GridSpec(1, 10) + gs.update(wspace = 0., hspace = 0.) + ts = [10,12,14,16,18,19] + for t in range(len(ts)): + #if t==0 : ax1=plt.subplot(gs[t]) + ax1 = plt.subplot(gs[t]) + gen_image = gen_images_[ts[t], :, :, 0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922 + plt.imshow(gen_image, cmap = 'jet', vmin=270, vmax=300) + ax1.title.set_text("t = " + str(ts[t]+1)) + plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = []) + + plt.savefig(os.path.join(args.output_png_dir, "Predicted_Sample_" + str(name) + ".jpg")) + plt.clf() + + # fig = plt.figure() + # gs = gridspec.GridSpec(4,6) + # gs.update(wspace = 0.7,hspace=0.8) + # ax1 = plt.subplot(gs[0:2,0:3]) + # ax2 = plt.subplot(gs[0:2,3:],sharey=ax1) + # ax3 = plt.subplot(gs[2:4,0:3]) + # ax4 = plt.subplot(gs[2:4,3:]) + # xlables = [round(i,2) for i in list(np.linspace(np.min(lon),np.max(lon),5))] + # ylabels = [round(i,2) for i in list(np.linspace(np.max(lat),np.min(lat),5))] + # plt.setp([ax1,ax2,ax3],xticks=list(np.linspace(0,64,5)), xticklabels=xlables ,yticks=list(np.linspace(0,64,5)),yticklabels=ylabels) + # ax1.title.set_text("(a) Ground Truth") + # ax2.title.set_text("(b) SAVP") + # ax3.title.set_text("(c) Diff.") + # ax4.title.set_text("(d) MSE") + # + # ax1.xaxis.set_tick_params(labelsize=7) + # ax1.yaxis.set_tick_params(labelsize = 7) + # ax2.xaxis.set_tick_params(labelsize=7) + # ax2.yaxis.set_tick_params(labelsize = 7) + # ax3.xaxis.set_tick_params(labelsize=7) + # ax3.yaxis.set_tick_params(labelsize = 7) + # + # init_images = np.zeros((input_images_.shape[1], input_images_.shape[2])) + # print("inti images shape", init_images.shape) + # xdata, ydata = [], [] + # #plot1 = ax1.imshow(init_images, cmap='jet', vmin =0, vmax = 1) + # #plot2 = ax2.imshow(init_images, cmap='jet', vmin =0, vmax = 1) + # plot1 = ax1.imshow(init_images, cmap='jet', vmin = 270, vmax = 300) + # plot2 = ax2.imshow(init_images, cmap='jet', vmin = 270, vmax = 300) + # #x = np.linspace(0, 64, 64) + # #y = np.linspace(0, 64, 64) + # #plot1 = ax1.contourf(x,y,init_images, cmap='jet', vmin = np.min(input_images), vmax = np.max(input_images)) + # #plot2 = ax2.contourf(x,y,init_images, cmap='jet', vmin = np.min(input_images), vmax = np.max(input_images)) + # fig.colorbar(plot1, ax=ax1).ax.tick_params(labelsize=7) + # fig.colorbar(plot2, ax=ax2).ax.tick_params(labelsize=7) + # + # cm = LinearSegmentedColormap.from_list( + # cmap_name, "bwr", N = 5) + # + # plot3 = ax3.imshow(init_images, vmin=-20, vmax=20, cmap=cm)#cmap = 'PuBu_r', + # #plot3 = ax3.imshow(init_images, vmin = -1, vmax = 1, cmap = cm) # cmap = 'PuBu_r', + # plot4, = ax4.plot([], [], color = "r") + # ax4.set_xlim(0, future_length-1) + # ax4.set_ylim(0, 20) + # #ax4.set_ylim(0, 0.5) + # ax4.set_xlabel("Frames", fontsize=10) + # #ax4.set_ylabel("MSE", fontsize=10) + # ax4.xaxis.set_tick_params(labelsize=7) + # ax4.yaxis.set_tick_params(labelsize=7) + # + # + # plots = [plot1, plot2, plot3, plot4] + # + # #fig.colorbar(plots[1], ax = [ax1, ax2]) + # + # fig.colorbar(plots[2], ax=ax3).ax.tick_params(labelsize=7) + # #fig.colorbar(plot1[0], ax=ax1).ax.tick_params(labelsize=7) + # #fig.colorbar(plot2[1], ax=ax2).ax.tick_params(labelsize=7) + # + # def animation_sample(t): + # input_image = input_images_[t, :, :, 0]* (321.46630859375-235.2141571044922) + 235.2141571044922 + # gen_image = gen_images_[t, :, :, 0]* (321.46630859375-235.2141571044922) + 235.2141571044922 + # diff_image = input_gen_diff[t,:,:] + # # p = sns.lineplot(x=x,y=data,color="b") + # # p.tick_params(labelsize=17) + # # plt.setp(p.lines, linewidth=6) + # plots[0].set_data(input_image) + # plots[1].set_data(gen_image) + # #plots[0] = ax1.contourf(x, y, input_image, cmap = 'jet', vmin = np.min(input_images),vmax = np.max(input_images)) + # #plots[1] = ax2.contourf(x, y, gen_image, cmap = 'jet', vmin = np.min(input_images),vmax = np.max(input_images)) + # plots[2].set_data(diff_image) + # + # if t >= future_length: + # #data = gen_mse_avg_[:t + 1] + # # x = list(range(len(gen_mse_avg_)))[:t+1] + # xdata.append(t-future_length) + # print("xdata", xdata) + # ydata.append(gen_mse_avg_[t]) + # print("ydata", ydata) + # plots[3].set_data(xdata, ydata) + # fig.suptitle("Predicted Frame " + str(t-future_length)) + # else: + # #plots[3].set_data(xdata, ydata) + # fig.suptitle("Context Frame " + str(t)) + # return plots + # + # ani = animation.FuncAnimation(fig, animation_sample, frames=len(gen_mse_avg_), interval = 1000, + # repeat_delay=2000) + # ani.save(os.path.join(args.output_png_dir, "Sample_" + str(name) + ".mp4")) + + else: + pass + + + + if sample_ind == 0: + gen_images_all = gen_images_stochastic + else: + gen_images_all = np.concatenate((np.array(gen_images_all), np.array(gen_images_stochastic)), axis=1) + + if args.num_stochastic_samples == 1: + with open(os.path.join(args.output_png_dir, "gen_images_all"), "wb") as gen_files: + pickle.dump(list(gen_images_all[0]), gen_files) + else: + with open(os.path.join(args.output_png_dir, "gen_images_sample_id_" + str(sample_ind)),"wb") as gen_files: + pickle.dump(list(gen_images_stochastic), gen_files) + with open(os.path.join(args.output_png_dir, "gen_images_all_stochastic"), "wb") as gen_files: + pickle.dump(list(gen_images_all), gen_files) + + + + + sample_ind += args.batch_size # # for i, gen_mse_avg_ in enumerate(gen_mse_avg): @@ -390,8 +390,7 @@ def main(): # # plt.xlabel("Frames") # # plt.ylabel("MSE_AVG") # # #X = list(range(len(gen_mse_avg_))) - # # #for t, gen_mse_avg_ in enume - # rate(gen_mse_avg): + # # #for t, gen_mse_avg_ in enumerate(gen_mse_avg): # # def animate_metric(j): # # data = gen_mse_avg_[:(j+1)] # # x = list(range(len(gen_mse_avg_)))[:(j+1)]