diff --git a/video_prediction_savp/scripts/generate_transfer_learning_finetune.py b/video_prediction_savp/scripts/generate_transfer_learning_finetune.py index 82ba3a6adb900e3456aeffd203180bd429b234f7..fdeb585e60298411bd768fadbb2800dcfd574364 100644 --- a/video_prediction_savp/scripts/generate_transfer_learning_finetune.py +++ b/video_prediction_savp/scripts/generate_transfer_learning_finetune.py @@ -337,11 +337,13 @@ def get_persistence(ts, input_dir_pkl): Inputs: ts: output by generate_seq_timestamps(t_start,len_seq=sequence_length) Is a list containing dateime objects + + input_dir_pkl: input directory to pickle files Ouputs: - ts_persistence: list containing the dates and times of the + time_persistence: list containing the dates and times of the persistence forecast. - peristence_images: sequence of images corresponding to the times + var_peristence : sequence of images corresponding to the times in ts_persistence """ ts_persistence = [] @@ -354,26 +356,40 @@ def get_persistence(ts, input_dir_pkl): month_start = t_persistence_start.month month_end = t_persistence_end.month - if month_start == month_end: - pass - else: # case that we need to derive the data from two pickle files + # only one pickle file is needed (all hours during the same month) + if month_start == month_end: + # Open files to search for the indizes of the corresponding time + time_pickle = load_pickle_for_persistence(input_dir_pkl, year_start, month_start, 'T') + # Open file to search for the correspoding meteorological fields + var_pickle = load_pickle_for_persistence(input_dir_pkl, year_start, month_start, 'X') + # Retrieve starting index + ind = list(time_pickle).index(np.array(ts_persistence[0])) + #print('Scarlet, Original', ts_persistence) + #print('From Pickle', time_pickle[ind:ind+len(ts_persistence)]) + + var_persistence = var_pickle[ind:ind+len(ts_persistence)] + time_persistence = time_pickle[ind:ind+len(ts_persistence)].ravel() + print(' Scarlet Shape of time persistence',time_persistence.shape) + #print(' Scarlet Shape of var persistence',var_persistence.shape) + + + # case that we need to derive the data from two pickle files (changing month during the forecast periode) + else: t_persistence_first_m = [] # should hold dates of the first month t_persistence_second_m = [] # should hold dates of the second month + for t in range(len(ts)): m = ts_persistence[t].month if m == month_start: t_persistence_first_m.append(ts_persistence[t]) if m == month_end: t_persistence_second_m.append(ts_persistence[t]) - # Open files to search for the indizes - #path_to_pickle = input_dir_pkl+'/'+str(year_start)+'/T_{:02}.pkl'.format(month_start) - #infile = open(path_to_pickle,'rb') + + # Open files to search for the indizes of the corresponding time time_pickle_first = load_pickle_for_persistence(input_dir_pkl, year_start, month_start, 'T') time_pickle_second = load_pickle_for_persistence(input_dir_pkl, year_start, month_end, 'T') # Open file to search for the correspoding meteorological fields - #path_to_pickle = input_dir_pkl+'/'+str(year_start)+'/X_{:02}.pkl'.format(month_start) - #infile = open(path_to_pickle,'rb') var_pickle_first = load_pickle_for_persistence(input_dir_pkl, year_start, month_start, 'X') var_pickle_second = load_pickle_for_persistence(input_dir_pkl, year_start, month_end, 'X') @@ -381,16 +397,25 @@ def get_persistence(ts, input_dir_pkl): ind_first_m = list(time_pickle_first).index(np.array(t_persistence_first_m[0])) ind_second_m = list(time_pickle_second).index(np.array(t_persistence_second_m[0])) - # construct following indices and store them in a list - DO I NEED THAT? - idx_fist = construct_index(ind_second_m, t_persistence_second_m) - idx_second = construct_index(ind_second_m, t_persistence_second_m) - print('Scarlet this is the index {} in month {}'.format(ind_first_m, month_start)) - print('length sequence first month', len(t_persistence_first_m)) - print('Scarlet this is the index {} in month {}'.format(ind_second_m, month_end)) - print('length sequence second month', len(t_persistence_second_m)) - print('Original', ts_persistence) - print('From Pickle', time_pickle_first[ind_first_m:ind_first_m+len(t_persistence_first_m)]) + #print('Scarlet, Original', ts_persistence) + #print('From Pickle', time_pickle_first[ind_first_m:ind_first_m+len(t_persistence_first_m)], time_pickle_second[ind_second_m:ind_second_m+len(t_persistence_second_m)]) + #print(' Scarlet before', time_pickle_first[ind_first_m:ind_first_m+len(t_persistence_first_m)].shape, time_pickle_second[ind_second_m:ind_second_m+len(t_persistence_second_m)].shape) + # append the sequence of the second month to the first month + var_persistence = np.concatenate((var_pickle_first[ind_first_m:ind_first_m+len(t_persistence_first_m)], + var_pickle_second[ind_second_m:ind_second_m+len(t_persistence_second_m)]), + axis=0) + time_persistence = np.concatenate((time_pickle_first[ind_first_m:ind_first_m+len(t_persistence_first_m)], + time_pickle_second[ind_second_m:ind_second_m+len(t_persistence_second_m)]), + axis=0).ravel() # ravel is needed to eliminate the unnecessary dimension (20,1) becomes (20,) + print(' Scarlet concatenate and ravel (time)', var_persistence.shape, time_persistence.shape) + + + # tolist() is needed for plotting + return var_persistence, time_persistence.tolist() + + + def load_pickle_for_persistence(input_dir_pkl, year_start, month_start, pkl_type): """Helper to get the content of the pickle files. There are two types in our workflow: T_[month].pkl where the time stamp is stored @@ -402,37 +427,6 @@ def load_pickle_for_persistence(input_dir_pkl, year_start, month_start, pkl_type var = pickle.load(infile) return var -def construct_index(first_idx, sequence): - """Helper function to construct the idx sequences from the starting index and the lenght - of the squence. - Wait do I even need a sequence????""" - idx_list = [] - for i in range(first_idx, first_idx+len(sequence)): - #print(i) - idx_list.append(i) - return idx_list - - - # Oder stuff: - # Retrieve indizes - #print(list(time_pickle)) - #for item in time_pickle: - # time_pickle_list.append(item) - #print(time_pickle_list) - #print(time_pickle) - #print(t_persistence_first_m) - #idx_first_m = time_pickle_list.index(t_persistence_first_m) - #print('Scarlet, indizes', idx_first_m) - - #persistent_idx = list(test_temporal_pkl).index(np.array(persistent_ts)) - #path_to_pickle = input_dir_pkl+'/'+str(year_start)+'/X_{:02}.pkl'.format(month_start) - #print('hypothetical path', path_to_pickle) - - #infile = open(path_to_pickle,'rb') - #print('Scarlet start persistence and start regular', t_persistence_start, ts[0] ) - - - #return ts_persistence, peristence_images def main(): parser = argparse.ArgumentParser() @@ -582,9 +576,17 @@ def main(): #Generate forecast images plot_seq_imgs(imgs=gen_images_denorm[context_frames:,:,:,0],lats=lats,lons=lons,ts=ts[context_frames+1:],label="Forecast by Model " + args.model,output_png_dir=args.results_dir) - #+++ Scarlet 20200828 - get_persistence(ts, input_dir_pkl) - #--- Scarlet 20200528 + #+++ Scarlet 20200922 + print('Scarlet', type(ts[context_frames+1:])) + print('ts', ts[context_frames+1:]) + print('context_frames:', context_frames) + persistence_images, ts_persistence = get_persistence(ts, input_dir_pkl) + print('Scarlet', type(ts_persistence)) + # I am not sure about the number of frames given with context_frames and context_frames +1 + plot_seq_imgs(imgs=persistence_images[context_frames+1:,:,:,0],lats=lats,lons=lons,ts=ts_persistence[context_frames+1:], + label="Persistence Forecast" + args.model,output_png_dir=args.results_dir) + #--- Scarlet 20200922 + #in case of generate the images for all the input, we just generate the first 5 sampe_ind examples for visuliation sample_ind += args.batch_size