diff --git a/video_prediction/datasets/era5_dataset_v2.py b/video_prediction/datasets/era5_dataset_v2.py index 63a11d2292e7c91dbf1801ec3507b971c64cf700..379e5662a58af91dee726fcff366464eaa84e956 100644 --- a/video_prediction/datasets/era5_dataset_v2.py +++ b/video_prediction/datasets/era5_dataset_v2.py @@ -201,7 +201,7 @@ class norm_data: self.status_ok = True - def normalize_var(self,data,varname,norm): + def norm_var(self,data,varname,norm): # some sanity checks if not self.status_ok: raise ValueError("norm_data-object needs to be initialized and checked first.") @@ -217,7 +217,7 @@ class norm_data: elif norm == "znorm": return((data[...] - getattr(self,varname+"avg"))/getattr(self,varname+"sigma")**2) - def denormalize_var(self,data,varname,norm): + def denorm_var(self,data,varname,norm): # some sanity checks if not self.status_ok: raise ValueError("norm_data-object needs to be initialized and checked first.") @@ -238,27 +238,28 @@ def read_frames_and_save_tf_records(output_dir,input_dir,partition_name,vars_in, # ML 2020/04/08: # Include vars_in for more flexible data handling (normalization and reshaping) # and optional keyword argument for kind of normalization - known_norms = ["minmax"] # may be more elegant to define a class here? + + if n output_dir = os.path.join(output_dir,partition_name) os.makedirs(output_dir,exist_ok=True) - norm = norm_data(vars_in) + norm_cls = norm_data(vars_in) nvars = len(vars_in) - vars_uni, indrev = np.unique(vars_in,return_inverse=True) - if 'norm' in kwargs: - norm = kwargs.get("norm") - if (not norm in knwon_norms): - raise ValueError("Pass valid normalization identifier.") - print("Known identifiers are: ") - for norm_name in known_norm: - print('"'+norm_name+'"') - else: - norm = "minmax" + #vars_uni, indrev = np.unique(vars_in,return_inverse=True) + #if 'norm' in kwargs: + #norm = kwargs.get("norm") + #if (not norm in knwon_norms): + #raise ValueError("Pass valid normalization identifier.") + #print("Known identifiers are: ") + #for norm_name in known_norm: + #print('"'+norm_name+'"') + #else: + #norm = "minmax" # open statistics file and store the dictionary with open(os.path.join(input_dir,"statistics.json")) as js_file: - norm.check_and_set_norm(json.load(js_file),norm_name) + norm_cls.check_and_set_norm(json.load(js_file),norm_name) #if (norm == "minmax"): #varmin, varmax = get_stat_allvars(data,"min",vars_in), get_stat_allvars(data,"max",vars_in) @@ -295,7 +296,7 @@ def read_frames_and_save_tf_records(output_dir,input_dir,partition_name,vars_in, # a) normalization should be cast in class definition (with initialization, setting of norm. approach including # data retrieval and the normalization itself for i in range(nvars): - sequences[:,:,:,:,i] = norm.normalize_var(sequences[:,:,:,:,i],vars_in[i],norm_name) + sequences[:,:,:,:,i] = norm_cls.norm_var(sequences[:,:,:,:,i],vars_in[i],norm_name) output_fname = 'sequence_{0}_to_{1}.tfrecords'.format(last_start_sequence_iter, sequence_iter - 1) output_fname = os.path.join(output_dir, output_fname)