diff --git a/video_prediction_savp/scripts/generate_transfer_learning_finetune.py b/video_prediction_savp/scripts/generate_transfer_learning_finetune.py index 7948d650cc270c9ee33dcd32c2e03c70f9216225..5460d4f8768d0d93ffac3a42ff3345334ebb2026 100644 --- a/video_prediction_savp/scripts/generate_transfer_learning_finetune.py +++ b/video_prediction_savp/scripts/generate_transfer_learning_finetune.py @@ -329,8 +329,102 @@ def plot_seq_imgs(imgs,lats,lons,ts,output_png_dir,label="Ground Truth"): print("image {} saved".format(output_fname)) -def get_persistence(ts): - pass +def get_persistence(ts, input_dir_pkl): + """This function gets the persistence forecast. + 'Today's weather will be like yesterday's weather. + + Inputs: + ts: output by generate_seq_timestamps(t_start,len_seq=sequence_length) + Is a list containing dateime objects + + input_dir_pkl: input directory to pickle files + + Ouputs: + time_persistence: list containing the dates and times of the + persistence forecast. + var_peristence : sequence of images corresponding to the times + in ts_persistence + """ + ts_persistence = [] + for t in range(len(ts)): # Scarlet: this certainly can be made nicer with list comprehension + ts_temp = ts[t] - datetime.timedelta(days=1) + ts_persistence.append(ts_temp) + t_persistence_start = ts_persistence[0] + t_persistence_end = ts_persistence[-1] + year_start = t_persistence_start.year + month_start = t_persistence_start.month + month_end = t_persistence_end.month + + # only one pickle file is needed (all hours during the same month) + if month_start == month_end: + # Open files to search for the indizes of the corresponding time + time_pickle = load_pickle_for_persistence(input_dir_pkl, year_start, month_start, 'T') + # Open file to search for the correspoding meteorological fields + var_pickle = load_pickle_for_persistence(input_dir_pkl, year_start, month_start, 'X') + # Retrieve starting index + ind = list(time_pickle).index(np.array(ts_persistence[0])) + #print('Scarlet, Original', ts_persistence) + #print('From Pickle', time_pickle[ind:ind+len(ts_persistence)]) + + var_persistence = var_pickle[ind:ind+len(ts_persistence)] + time_persistence = time_pickle[ind:ind+len(ts_persistence)].ravel() + print(' Scarlet Shape of time persistence',time_persistence.shape) + #print(' Scarlet Shape of var persistence',var_persistence.shape) + + + # case that we need to derive the data from two pickle files (changing month during the forecast periode) + else: + t_persistence_first_m = [] # should hold dates of the first month + t_persistence_second_m = [] # should hold dates of the second month + + for t in range(len(ts)): + m = ts_persistence[t].month + if m == month_start: + t_persistence_first_m.append(ts_persistence[t]) + if m == month_end: + t_persistence_second_m.append(ts_persistence[t]) + + # Open files to search for the indizes of the corresponding time + time_pickle_first = load_pickle_for_persistence(input_dir_pkl, year_start, month_start, 'T') + time_pickle_second = load_pickle_for_persistence(input_dir_pkl, year_start, month_end, 'T') + + # Open file to search for the correspoding meteorological fields + var_pickle_first = load_pickle_for_persistence(input_dir_pkl, year_start, month_start, 'X') + var_pickle_second = load_pickle_for_persistence(input_dir_pkl, year_start, month_end, 'X') + + # Retrieve starting index + ind_first_m = list(time_pickle_first).index(np.array(t_persistence_first_m[0])) + ind_second_m = list(time_pickle_second).index(np.array(t_persistence_second_m[0])) + + #print('Scarlet, Original', ts_persistence) + #print('From Pickle', time_pickle_first[ind_first_m:ind_first_m+len(t_persistence_first_m)], time_pickle_second[ind_second_m:ind_second_m+len(t_persistence_second_m)]) + #print(' Scarlet before', time_pickle_first[ind_first_m:ind_first_m+len(t_persistence_first_m)].shape, time_pickle_second[ind_second_m:ind_second_m+len(t_persistence_second_m)].shape) + + # append the sequence of the second month to the first month + var_persistence = np.concatenate((var_pickle_first[ind_first_m:ind_first_m+len(t_persistence_first_m)], + var_pickle_second[ind_second_m:ind_second_m+len(t_persistence_second_m)]), + axis=0) + time_persistence = np.concatenate((time_pickle_first[ind_first_m:ind_first_m+len(t_persistence_first_m)], + time_pickle_second[ind_second_m:ind_second_m+len(t_persistence_second_m)]), + axis=0).ravel() # ravel is needed to eliminate the unnecessary dimension (20,1) becomes (20,) + print(' Scarlet concatenate and ravel (time)', var_persistence.shape, time_persistence.shape) + + + # tolist() is needed for plotting + return var_persistence, time_persistence.tolist() + + + +def load_pickle_for_persistence(input_dir_pkl, year_start, month_start, pkl_type): + """Helper to get the content of the pickle files. There are two types in our workflow: + T_[month].pkl where the time stamp is stored + X_[month].pkl where the variables are stored, e.g. temperature, geopotential and pressure + This helper function constructs the directory, opens the file to read it, returns the variable. + """ + path_to_pickle = input_dir_pkl+'/'+str(year_start)+'/'+pkl_type+'_{:02}.pkl'.format(month_start) + infile = open(path_to_pickle,'rb') + var = pickle.load(infile) + return var def main(): @@ -377,6 +471,11 @@ def main(): input_dir_tf = os.path.join(args.input_dir, "tfrecords") # where tensorflow records are stored dataset = setup_dataset(dataset,input_dir_tf,args.mode,args.seed,args.num_epochs,args.dataset_hparams,dataset_hparams_dict) + # +++Scarlet 20200828 + input_dir_pkl = os.path.join(args.input_dir, "pickle") + # where pickle files records are stored, needed for the persistance forecast. + # ---Scarlet 20200828 + print("Step 2 finished") VideoPredictionModel = models.get_model_class(model) @@ -476,9 +575,17 @@ def main(): #Generate forecast images plot_seq_imgs(imgs=gen_images_denorm[context_frames:,:,:,0],lats=lats,lons=lons,ts=ts[context_frames+1:],label="Forecast by Model " + args.model,output_png_dir=args.results_dir) - #TODO: Scaret plot persistence image - #implment get_persistence() function - + #+++ Scarlet 20200922 + print('Scarlet', type(ts[context_frames+1:])) + print('ts', ts[context_frames+1:]) + print('context_frames:', context_frames) + persistence_images, ts_persistence = get_persistence(ts, input_dir_pkl) + print('Scarlet', type(ts_persistence)) + # I am not sure about the number of frames given with context_frames and context_frames +1 + plot_seq_imgs(imgs=persistence_images[context_frames+1:,:,:,0],lats=lats,lons=lons,ts=ts_persistence[context_frames+1:], + label="Persistence Forecast" + args.model,output_png_dir=args.results_dir) + #--- Scarlet 20200922 + #in case of generate the images for all the input, we just generate the first 5 sampe_ind examples for visuliation sample_ind += args.batch_size