diff --git a/video_prediction_savp/scripts/generate_transfer_learning_finetune.py b/video_prediction_savp/scripts/generate_transfer_learning_finetune.py index fb61f5f123fd55bcfc8a44652843ee70ba4dfcb8..0e250b47df28d115c8cdfc77fc708eab5e094ce6 100644 --- a/video_prediction_savp/scripts/generate_transfer_learning_finetune.py +++ b/video_prediction_savp/scripts/generate_transfer_learning_finetune.py @@ -109,7 +109,6 @@ def setup_dirs(input_dir,results_png_dir): print ("temporal_dir:",temporal_dir) - def update_hparams_dict(model_hparams_dict,dataset): hparams_dict = dict(model_hparams_dict) hparams_dict.update({ @@ -118,7 +117,7 @@ def update_hparams_dict(model_hparams_dict,dataset): 'repeat': dataset.hparams.time_shift, }) return hparams_dict - + def psnr(img1, img2): mse = np.mean((img1 - img2) ** 2) @@ -159,6 +158,7 @@ def write_params_to_results_dir(args,output_dir,dataset,model): f.write(json.dumps(model.hparams.values(), sort_keys = True, indent = 4)) return None + def denorm_images(stat_fl, input_images_,channel,var): norm_cls = Norm_data(var) norm = 'minmax' @@ -176,11 +176,12 @@ def denorm_images_all_channels(stat_fl,input_images_,*args): print("args c:", args[c]) input_images_all_channles_denorm.append(denorm_images(stat_fl,input_images_,channel=c,var=args[c])) input_images_denorm = np.stack(input_images_all_channles_denorm, axis=-1) - print("input_images_denorm shape",input_images_denorm.shape) + #print("input_images_denorm shape",input_images_denorm.shape) return input_images_denorm def get_one_seq_and_time(input_images,t_starts,i): - input_images_ = input_images[i, :] + assert (len(np.array(input_images).shape)==5) + input_images_ = input_images[i,:,:,:,:] t_start = t_starts[i] return input_images_,t_start @@ -194,19 +195,20 @@ def generate_seq_timestamps(t_start,len_seq=20): def save_to_netcdf_per_sequence(output_dir,input_images_,gen_images_,lons,lats,ts,context_frames,future_length,model_name,fl_name="test.nc"): + assert (len(np.array(input_images_).shape)==len(np.array(gen_images_).shape)) y_len = len(lats) x_len = len(lons) ts_len = len(ts) ts_input = ts[:context_frames] ts_forecast = ts[context_frames:] - print("context_frame:",context_frames) - print("future_frame",future_length) - print("length of ts input:",len(ts_input)) - + #print("context_frame:",context_frames) + #print("future_frame",future_length) + #print("length of ts input:",len(ts_input)) + print("input_images_ shape in netcdf,",input_images_.shape) gen_images_ = np.array(gen_images_) - + output_file = os.path.join(output_dir,fl_name) with Dataset(output_file, "w", format="NETCDF4") as nc_file: nc_file.title = 'ERA5 hourly reanalysis data and the forecasting data by deep learning for 2-m above sea level temperatures' @@ -281,7 +283,6 @@ def save_to_netcdf_per_sequence(output_dir,input_images_,gen_images_,lons,lats,t gph500_r.units = 'm' gph500_r[:,:,:] = input_images_[context_frames:,:,:,2] - ################ forecast group ##################### @@ -290,21 +291,53 @@ def save_to_netcdf_per_sequence(output_dir,input_images_,gen_images_,lons,lats,t t2.units = 'K' t2[:,:,:] = gen_images_[context_frames:,:,:,0] print("NetCDF created") - + #mean sea level pressure msl = nc_file.createVariable("/forecast/{}/MSL".format(model_name),"f4",("time_forecast","lat","lon"), zlib = True) msl.units = 'Pa' msl[:,:,:] = gen_images_[context_frames:,:,:,1] - + #Geopotential at 500 gph500 = nc_file.createVariable("/forecast/{}/GPH500".format(model_name),"f4",("time_forecast","lat","lon"), zlib = True) gph500.units = 'm' gph500[:,:,:] = gen_images_[context_frames:,:,:,2] - - print("{} created".format(output_file)) - + + print("{} created".format(output_file)) + return None +def plot_seq_imgs(imgs,lats,lons,ts,output_png_dir,label="Ground Truth"): + """ + Plot the seq images + """ + + if len(np.array(imgs).shape)!=3:raise("img dims should be four: (seq_len,lat,lon)") + if np.array(imgs).shape[0]!= len(ts): raise("The len of timestamps should be equal the image seq_len") + fig = plt.figure(figsize=(18,6)) + gs = gridspec.GridSpec(1, 10) + gs.update(wspace = 0., hspace = 0.) + xlables = [round(i,2) for i in list(np.linspace(np.min(lons),np.max(lons),5))] + ylabels = [round(i,2) for i in list(np.linspace(np.max(lats),np.min(lats),5))] + for i in range(len(ts)): + t = ts[i] + #if i==0 : ax1=plt.subplot(gs[i]) + ax1 = plt.subplot(gs[i]) + plt.imshow(imgs[i] ,cmap = 'jet', vmin=270, vmax=300) + ax1.title.set_text("t = " + t.strftime("%Y%m%d%H")) + plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = []) + if i == 0: + plt.setp([ax1], xticks = list(np.linspace(0, len(lons), 5)), xticklabels = xlables, yticks = list(np.linspace(0, len(lats), 5)), yticklabels = ylabels) + plt.ylabel(label, fontsize=10) + plt.savefig(os.path.join(output_png_dir, label + "_TS_" + str(ts[0]) + ".jpg")) + plt.clf() + output_fname = label + "_TS_" + ts[0].strftime("%Y%m%d%H") + ".jpg" + print("image {} saved".format(output_fname)) + + +def get_persistence(ts): + pass + + def main(): parser = argparse.ArgumentParser() parser.add_argument("--input_dir", type = str, required = True, @@ -371,7 +404,9 @@ def main(): num_examples_per_epoch = setup_num_samples_per_epoch(args.num_samples,dataset) inputs = dataset.make_batch(args.batch_size) + print("inputs",inputs) input_phs = {k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k) for k, v in inputs.items()} + print("input_phs",input_phs) # Build graph @@ -399,6 +434,7 @@ def main(): #---Scarlet:20200803 #while True: #Change True to sample_id<=24 for debugging + #loop for in samples while sample_ind < 5: gen_images_stochastic = [] @@ -407,28 +443,52 @@ def main(): try: input_results = sess.run(inputs) input_images = input_results["images"] + #get the intial times t_starts = input_results["T_start"] - print("T_starts:",t_starts) except tf.errors.OutOfRangeError: break + + #Get prediction values feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()} + gen_images = sess.run(model.outputs['gen_images'], feed_dict = feed_dict)#return [batchsize,seq_len,lat,lon,channel] + #Loop in batch size for i in range(args.batch_size): + #get one seq and the corresponding start time point input_images_,t_start = get_one_seq_and_time(input_images,t_starts,i) #generate time stamps for sequences ts = generate_seq_timestamps(t_start,len_seq=sequence_length) + #Renormalized data for inputs - stat_fl = os.path.join(args.input_dir,"statistics.json") - input_images_denorm = denorm_images_all_channels(stat_fl,input_images_,["T2","MSL","gph500"]) - #TODO: Just for creating the netCDF file and we copy the input_image_denorm as generate_images_denorm before we got our trained data - gen_images_denorm = input_images_denorm #(seq,lat,lon,var) + stat_fl = os.path.join(args.input_dir,"pickle/statistics.json") + input_images_denorm = denorm_images_all_channels(stat_fl,input_images_,["T2","MSL","gph500"]) + print("input_images_denorm",input_images_denorm[0][0]) + + #Renormalized data for inputs + gen_images_ = gen_images[i] + gen_images_denorm = denorm_images_all_channels(stat_fl,gen_images_,["T2","MSL","gph500"]) + print("gene_images_denorm:",gen_images_denorm[0][0]) + #Save input to netCDF file init_date_str = ts[0].strftime("%Y%m%d%H") save_to_netcdf_per_sequence(args.results_dir,input_images_denorm,gen_images_denorm,lons,lats,ts,context_frames,future_length,args.model,fl_name="vfp_{}.nc".format(init_date_str)) + + #Generate images inputs + plot_seq_imgs(imgs=input_images_denorm[:context_frames-1,:,:,0],lats=lats,lons=lons,ts=ts[:context_frames-1],label="Ground Truth",output_png_dir=args.results_dir) + + #Generate forecast images + plot_seq_imgs(imgs=gen_images_denorm[context_frames:,:,:,0],lats=lats,lons=lons,ts=ts[context_frames:],label="Forecast by Model " + args.model,output_png_dir=args.results_dir) + #TODO: Scaret plot persistence image + #implment get_persistence() function + + #in case of generate the images for all the input, we just generate the first 5 sampe_ind examples for visuliation + sample_ind += args.batch_size - #for input_image in input_images_: + + + #for input_image in input_images_: # for stochastic_sample_ind in range(args.num_stochastic_samples): # input_images_all.extend(input_images) diff --git a/video_prediction_savp/scripts/train_dummy.py b/video_prediction_savp/scripts/train_dummy.py index 9c621bbfc46dcf8cebb77b006604ed2f6f50056f..1fb401955c39be4807cf7747e43ed660941cb925 100644 --- a/video_prediction_savp/scripts/train_dummy.py +++ b/video_prediction_savp/scripts/train_dummy.py @@ -14,6 +14,8 @@ from video_prediction import datasets, models import matplotlib.pyplot as plt from json import JSONEncoder import pickle as pkl + + class NumpyArrayEncoder(JSONEncoder): def default(self, obj): if isinstance(obj, np.ndarray): @@ -261,9 +263,10 @@ def main(): print("parameter_count =", sess.run(parameter_count)) sess.run(tf.global_variables_initializer()) sess.run(tf.local_variables_initializer()) - model.restore(sess, args.checkpoint) + #model.restore(sess, args.checkpoint) sess.graph.finalize() start_step = sess.run(model.global_step) + print("start_step", start_step) # start at one step earlier to log everything without doing any training # step is relative to the start_step train_losses=[] @@ -286,6 +289,10 @@ def main(): fetches["L_gdl"] = model.L_gdl fetches["L_GAN"] =model.L_GAN + if model.__class__.__name__ == "SAVP": + #todo + pass + fetches["summary"] = model.summary_op results = sess.run(fetches) train_losses.append(results["total_loss"]) diff --git a/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py b/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py index 7560a225e7651728e2ca8d2107d7f32458106c86..7e3fec28dc28c78b8203e1924f17489af8f5075e 100644 --- a/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py +++ b/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py @@ -65,8 +65,8 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel): original_global_variables = tf.global_variables() # ARCHITECTURE self.convLSTM_network() - print("self.x",self.x) - print("self.x_hat_context_frames,",self.x_hat_context_frames) + #print("self.x",self.x) + #print("self.x_hat_context_frames,",self.x_hat_context_frames) #self.context_frames_loss = tf.reduce_mean( # tf.square(self.x[:, :self.context_frames, :, :, 0] - self.x_hat_context_frames[:, :, :, :, 0])) self.total_loss = tf.reduce_mean( @@ -81,7 +81,7 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel): self.summary_op = tf.summary.merge_all() global_variables = [var for var in tf.global_variables() if var not in original_global_variables] self.saveable_variables = [self.global_step] + global_variables - return + return None @staticmethod