From a25e78d408f53f968fc7ccecce88a0c54a4476c6 Mon Sep 17 00:00:00 2001 From: michael <m.langguth@fz-juelich.de> Date: Tue, 26 May 2020 10:25:35 +0200 Subject: [PATCH] Removal of several typos and other bugfixes. --- video_prediction/datasets/era5_dataset_v2.py | 28 ++++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/video_prediction/datasets/era5_dataset_v2.py b/video_prediction/datasets/era5_dataset_v2.py index dec2b559..2eac367d 100644 --- a/video_prediction/datasets/era5_dataset_v2.py +++ b/video_prediction/datasets/era5_dataset_v2.py @@ -167,7 +167,8 @@ def save_tf_record(output_fname, sequences): class norm_data: - knwon_norms = {} + ### set known norms and the requested statistics (to be retrieved from statistics.json) here ### + known_norms = {} known_norms["minmax"] = ["min","max"] known_norms["znorm"] = ["avg","sigma"] @@ -184,8 +185,8 @@ class norm_data: for norm_avail in self.known_norms.keys(): print(norm_avail) raise ValueError("Passed normalization '"+norm+"' is unknown.") - - if not all(self.varnames in stat_dict): + + if not all(items in stat_dict for items in self.varnames): print("Keys in stat_dict:") print(stat_dict.keys()) @@ -193,11 +194,14 @@ class norm_data: print(self.varnames) raise ValueError("Could not find all requested variables in statistics dictionary.") - for varname in varnames_uni: - for stat_name in knwon_norms[norm]: + for varname in self.varnames: + for stat_name in self.known_norms[norm]: setattr(self,varname+stat_name,stat_dict[varname][0][stat_name]) self.status_ok = True + for i in range(len(self.varnames)): + print(self.varnames[i]) + print(getattr(self,self.varnames[i]+"min")) def norm_var(self,data,varname,norm): @@ -226,9 +230,9 @@ class norm_data: print(norm_avail) raise ValueError("Passed normalization '"+norm+"' is unknown.") - if norm = "minmax": + if norm == "minmax": return(data[...] * (getattr(self,varname+"max") - getattr(self,varname+"min")) + getattr(self,varname+"max")) - elif norm = "znorm": + elif norm == "znorm": return(data[...] * getattr(self,varname+"sigma")**2 + getattr(self,varname+"avg")) @@ -237,7 +241,7 @@ def read_frames_and_save_tf_records(output_dir,input_dir,partition_name,vars_in, # Include vars_in for more flexible data handling (normalization and reshaping) # and optional keyword argument for kind of normalization - if norm in kwargs: + if 'norm' in kwargs: norm = kwargs.get("norm") else: norm = "minmax" @@ -261,7 +265,7 @@ def read_frames_and_save_tf_records(output_dir,input_dir,partition_name,vars_in, # open statistics file and store the dictionary with open(os.path.join(input_dir,"statistics.json")) as js_file: - norm_cls.check_and_set_norm(json.load(js_file),norm_name) + norm_cls.check_and_set_norm(json.load(js_file),norm) #if (norm == "minmax"): #varmin, varmax = get_stat_allvars(data,"min",vars_in), get_stat_allvars(data,"max",vars_in) @@ -293,12 +297,8 @@ def read_frames_and_save_tf_records(output_dir,input_dir,partition_name,vars_in, ###Normalization should adpot the selected variables, here we used duplicated channel temperature variables sequences = np.array(sequences) ### normalization - # ML 2020/04/08: - # again rather inelegant/inefficient as... - # a) normalization should be cast in class definition (with initialization, setting of norm. approach including - # data retrieval and the normalization itself for i in range(nvars): - sequences[:,:,:,:,i] = norm_cls.norm_var(sequences[:,:,:,:,i],vars_in[i],norm_name) + sequences[:,:,:,:,i] = norm_cls.norm_var(sequences[:,:,:,:,i],vars_in[i],norm) output_fname = 'sequence_{0}_to_{1}.tfrecords'.format(last_start_sequence_iter, sequence_iter - 1) output_fname = os.path.join(output_dir, output_fname) -- GitLab