diff --git a/video_prediction_savp/metadata.py b/video_prediction_savp/metadata.py index c8cf648fe602ad602f76cc9d41754e7607f28d14..da0beb4f571c4ed9a81d94f0cbadfcce53ff02a0 100644 --- a/video_prediction_savp/metadata.py +++ b/video_prediction_savp/metadata.py @@ -23,7 +23,9 @@ class MetaData: method_name = MetaData.__init__.__name__+" of Class "+MetaData.__name__ if not json_file is None: - MetaData.get_metadata_from_file(json_file) + print(json_file) + print(type(json_file)) + MetaData.get_metadata_from_file(self,json_file) else: # No dictionary from json-file available, all other arguments have to set @@ -93,8 +95,8 @@ class MetaData: self.lat = datafile.variables['lat'][slices['lat_s']:slices['lat_e']] self.lon = datafile.variables['lon'][slices['lon_s']:slices['lon_e']] - # Now start constructing exp_dir-string - # switch sign and coordinate-flags to avoid negative values appearing in exp_dir-name + # Now start constructing expdir-string + # switch sign and coordinate-flags to avoid negative values appearing in expdir-name if sw_c[0] < 0.: sw_c[0] = np.abs(sw_c[0]) flag_coords[0] = "S" @@ -114,7 +116,7 @@ class MetaData: expdir, expname = path_parts[0], path_parts[1] - # extend exp_dir_in successively (splitted up for better readability) + # extend expdir_in successively (splitted up for better readability) expname += "-"+str(self.nx) + "x" + str(self.ny) expname += "-"+(("{0: 05.2f}"+flag_coords[0]+"{1:05.2f}"+flag_coords[1]).format(*sw_c)).strip().replace(".","")+"-" @@ -200,22 +202,24 @@ class MetaData: with open(js_file) as js_file: dict_in = json.load(js_file) - self.exp_dir = dict_in["exp_dir"] + self.expdir = dict_in["expdir"] self.sw_c = [dict_in["sw_corner_frame"]["lat"],dict_in["sw_corner_frame"]["lon"] ] self.lat = dict_in["coordinates"]["lat"] - self.lat = dict_in["coordinates"]["lon"] + self.lon = dict_in["coordinates"]["lon"] self.nx = dict_in["frame_size"]["nx"] self.ny = dict_in["frame_size"]["ny"] - - self.variables = [dict_in["variables"][ivar] for ivar in dict_in["variables"].keys()] - + # dict_in["variables"] is a list like [{var1: varname1},{var2: varname2},...] + list_of_dict_aux = dict_in["variables"] + # iterate through the list with an integer ivar + # note: the naming of the variables starts with var1, thus add 1 to the iterator + self.variables = [list_of_dict_aux[ivar]["var"+str(ivar+1)] for ivar in range(len(list_of_dict_aux))] def write_dirs_to_batch_scripts(self,batch_script): """ - Expands ('known') directory-variables in batch_script by exp_dir-attribute of class instance + Expands ('known') directory-variables in batch_script by expdir-attribute of class instance """ paths_to_mod = ["source_dir=","destination_dir=","checkpoint_dir=","results_dir="] # known directory-variables in batch-scripts diff --git a/video_prediction_savp/scripts/generate_transfer_learning_finetune.py b/video_prediction_savp/scripts/generate_transfer_learning_finetune.py index 09bf0afe2c46049d2a1b3e19667daed97b7a5419..13b93889875779942e5171e5e1d98eebc84fd9f3 100644 --- a/video_prediction_savp/scripts/generate_transfer_learning_finetune.py +++ b/video_prediction_savp/scripts/generate_transfer_learning_finetune.py @@ -36,6 +36,7 @@ sys.path.append(path.abspath('../video_prediction/datasets/')) from era5_dataset_v2 import Norm_data from os.path import dirname from netCDF4 import Dataset,date2num +from metadata import MetaData as MetaData def set_seed(seed): if seed is not None: @@ -43,14 +44,20 @@ def set_seed(seed): np.random.seed(seed) random.seed(seed) -#TODO: WE MUST REPLACE IT WITH Micha's meta.json files -with open("../geo_info.json","r") as json_file: - geo = json.load(json_file) - lat = [round(i,2) for i in geo["lat"]] - lon = [round(i,2) for i in geo["lon"]] +def get_coordinates(metadata_fname): + """ + Retrieves the latitudes and longitudes read from the metadata json file. + """ + md = MetaData(json_file=metadata_fname) + md.get_metadata_from_file(metadata_fname) + + try: + print("lat:",md.lat) + print("lon:",md.lon) + return md.lat, md.lon + except: + raise ValueError("Error when handling: '"+metadata_fname+"'") -print("lat:",lat) -print("lon:",lon) def load_checkpoints_and_create_output_dirs(checkpoint,dataset,model): if checkpoint: @@ -185,17 +192,7 @@ def generate_seq_timestamps(t_start,len_seq=20): seq_ts = [s_datetime + datetime. timedelta(hours = i+1) for i in range(len_seq)] return seq_ts - -def generate_coordinates(lat,lon,input_images_): - if len(input_images_.shape) !=4: raise ("The length of input_images_ should be equal to 4, but return shape:",np.array(input_images_).shape) - input_shape = input_images_.shape - len_ts,len_lat,len_lon,len_channel = input_shape[0],input_shape[1],input_shape[2],input_shape[3] - lons = [round(i,2) for i in list(np.linspace(np.min(lon),np.max(lon),len_lon))] - lats = [round(i,2) for i in list(np.linspace(np.min(lat),np.max(lat),len_lat))] - print("lenght of lons : {}; lenght of lats: {}".format(len(lons),len(lats))) - return lats,lons - def save_to_netcdf_per_sequence(output_dir,input_images_,gen_images_,lons,lats,ts, model_name,fl_name="test.nc"): y_len = len(lats) @@ -319,7 +316,8 @@ def main(): print('------------------------------------- End --------------------------------------') #setup dataset and model object - dataset = setup_dataset(dataset,args.input_dir,args.mode,args.seed,args.num_epochs,args.dataset_hparams,dataset_hparams_dict) + input_dir_tf = os.path.join(args.input_dir, "tfrecords") # where tensorflow records are stored + dataset = setup_dataset(dataset,input_dir_tf,args.mode,args.seed,args.num_epochs,args.dataset_hparams,dataset_hparams_dict) print("Step 2 finished") VideoPredictionModel = models.get_model_class(model) @@ -364,10 +362,10 @@ def main(): sample_ind, gen_images_all, persistent_images_all, input_images_all = initia_save_data() is_first=True - #+++Scarlet:20200528 - + #+++Scarlet:20200803 + lats, lons = get_coordinates(os.path.join(args.input_dir,"metadata.json")) - #---Scarlet:20200528 + #---Scarlet:20200803 #while True: #Change True to sample_id<=24 for debugging #loop for in samples @@ -388,12 +386,9 @@ def main(): #get one seq and the corresponding start time point input_images_,t_start = get_one_seq_and_time(input_images,t_starts,i) #generate time stamps for sequences - ts = generate_seq_timestamps(t_start,len_seq=sequence_length)[context_frames-1:] #This will include the intia time - #genereate coordinates - if sample_ind==0: lats,lons = generate_coordinates(lat,lon,input_images_) - print("lats:",lats) + ts = generate_seq_timestamps(t_start,len_seq=sequence_length)[context_frames-1:] #This will include the intia time #Renormalized data for inputs - stat_fl = os.path.join(dirname(args.input_dir),"hickle/statistics.json") + stat_fl = os.path.join(args.input_dir,"statistics.json") input_images_denorm = denorm_images_all_channels(stat_fl,input_images_,["T2","MSL","gph500"])[context_frames:,:,:,:] #TODO: Just for creating the netCDF file and we copy the input_image_denorm as generate_images_denorm before we got our trained data gen_images_denorm = input_images_denorm #(seq,lat,lon,var)