diff --git a/scripts/generate_transfer_learning_finetune.py b/scripts/generate_transfer_learning_finetune.py index c4fa831594910b3389a873cc9f8d4dd87944d66e..2a9245ab54e7ad72fdbf153504e2ec507d4688e2 100644 --- a/scripts/generate_transfer_learning_finetune.py +++ b/scripts/generate_transfer_learning_finetune.py @@ -31,6 +31,12 @@ from matplotlib.colors import LinearSegmentedColormap #from video_prediction.utils.ffmpeg_gif import save_gif from skimage.metrics import structural_similarity as ssim import datetime +# Scarlet 2020/05/28: access to statistical values in json file +from os import path +import sys +sys.path.append(path.abspath('../video_prediction/datasets/')) +from era5_dataset_v2 import Norm_data +from os.path import dirname with open("../geo_info.json","r") as json_file: geo = json.load(json_file) @@ -208,8 +214,13 @@ def main(): #X_val = hickle.load("/p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/era5-Y2017M01to12-64x64-50d00N11d50E-T_T_T/hickle/splits/X_val.hkl") X_test = hickle.load(os.path.join(temporal_dir,"X_test.hkl")) is_first=True - + #+++Scarlet:20200528 + norm_cls = Norm_data('T2') + norm = 'minmax' + with open(os.path.join(dirname(input_dir),"hickle/splits/statistics.json")) as js_file: + norm_cls.check_and_set_norm(json.load(js_file),norm) + #---Scarlet:20200528 while True: print("Sample id", sample_ind) if sample_ind <= 24: @@ -265,9 +276,11 @@ def main(): input_images_ = input_images[i, :] #Bing:20200417 #persistent_images = ? - input_gen_diff = (input_images_[:, :, :,0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922) - (gen_images_[:, :, :, 0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922) - persistent_diff = (input_images_[:, :, :,0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922) - (persistent_X[:, :, :, 0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922) - + #+++Scarlet:20200528 + #print('Scarlet1') + input_gen_diff = norm_cls.denorm_var(input_images_[:, :, :,0], 'T2', norm) - norm_cls.denorm_var(gen_images_[:, :, :, 0],'T2',norm) + persistent_diff = norm_cls.denorm_var(input_images_[:, :, :,0], 'T2', norm) - norm_cls.denorm_var(persistent_X[:, :, :, 0], 'T2',norm) + #---Scarlet:20200528 gen_mse_avg_ = [np.mean(input_gen_diff[frame, :, :] ** 2) for frame in range(sequence_length)] # return the list with 10 (sequence) mse persistent_mse_avg_ = [np.mean(persistent_diff[frame, :, :] ** 2) for frame in @@ -284,7 +297,10 @@ def main(): #if t==0 : ax1=plt.subplot(gs[t]) ax1 = plt.subplot(gs[ts.index(t)]) - input_image = input_images_[t, :, :, 0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922 + #+++Scarlet:20200528 + #print('Scarlet2') + input_image = norm_cls.denorm_var(input_images_[t, :, :, 0], 'T2', norm) + #---Scarlet:20200528 plt.imshow(input_image, cmap = 'jet', vmin=270, vmax=300) ax1.title.set_text("t = " + str(t+1-10)) plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = []) @@ -301,7 +317,10 @@ def main(): for t in ts: #if t==0 : ax1=plt.subplot(gs[t]) ax1 = plt.subplot(gs[ts.index(t)]) - gen_image = gen_images_[t, :, :, 0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922 + #+++Scarlet:20200528 + #print('Scarlet3') + gen_image = norm_cls.denorm_var(gen_images_[t, :, :, 0], 'T2', norm) + #---Scarlet:20200528 plt.imshow(gen_image, cmap = 'jet', vmin=270, vmax=300) ax1.title.set_text("t = " + str(t+1-10)) plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = []) @@ -538,13 +557,20 @@ def main(): with open(os.path.join(args.output_png_dir, "persistent_images_all.pkl"),"rb") as gen_files: persistent_images_all = pickle.load(gen_files) + #+++Scarlet:20200528 + #print('Scarlet4') input_images_all = np.array(input_images_all) - input_images_all = np.array(input_images_all) * (321.46630859375 - 235.2141571044922) + 235.2141571044922 + input_images_all = norm_cls.denorm_var(input_images_all, 'T2', norm) + #---Scarlet:20200528 persistent_images_all = np.array(persistent_images_all) if len(np.array(gen_images_all).shape) == 6: for i in range(len(gen_images_all)): + #+++Scarlet:20200528 + #print('Scarlet5') gen_images_all_stochastic = np.array(gen_images_all)[i,:,:,:,:,:] - gen_images_all_stochastic = np.array(gen_images_all_stochastic) * (321.46630859375 - 235.2141571044922) + 235.2141571044922 + gen_images_all_stochastic = norm_cls.denorm_var(gen_images_all_stochastic, 'T2', norm) + #gen_images_all_stochastic = np.array(gen_images_all_stochastic) * (321.46630859375 - 235.2141571044922) + 235.2141571044922 + #---Scarlet:20200528 mse_all = [] psnr_all = [] ssim_all = [] @@ -574,7 +600,11 @@ def main(): f.write("Shape of X_hat: " + str(gen_images_all_stochastic.shape)) else: - gen_images_all = np.array(gen_images_all) * (321.46630859375 - 235.2141571044922) + 235.2141571044922 + #+++Scarlet:20200528 + #print('Scarlet6') + gen_images_all = np.array(gen_images_all) + gen_images_all = norm_cls.denorm_var(gen_images_all, 'T2', norm) + #---Scarlet:20200528 # mse_model = np.mean((input_images_all[:, 1:,:,:,0] - gen_images_all[:, 1:,:,:,0])**2) # look at all timesteps except the first # mse_model_last = np.mean((input_images_all[:, future_length-1,:,:,0] - gen_images_all[:, future_length-1,:,:,0])**2)