Skip to content
Snippets Groups Projects
Commit a25e78d4 authored by Michael Langguth's avatar Michael Langguth
Browse files

Removal of several typos and other bugfixes.

parent 8ac40fec
Branches
Tags
No related merge requests found
...@@ -167,7 +167,8 @@ def save_tf_record(output_fname, sequences): ...@@ -167,7 +167,8 @@ def save_tf_record(output_fname, sequences):
class norm_data: class norm_data:
knwon_norms = {} ### set known norms and the requested statistics (to be retrieved from statistics.json) here ###
known_norms = {}
known_norms["minmax"] = ["min","max"] known_norms["minmax"] = ["min","max"]
known_norms["znorm"] = ["avg","sigma"] known_norms["znorm"] = ["avg","sigma"]
...@@ -185,7 +186,7 @@ class norm_data: ...@@ -185,7 +186,7 @@ class norm_data:
print(norm_avail) print(norm_avail)
raise ValueError("Passed normalization '"+norm+"' is unknown.") raise ValueError("Passed normalization '"+norm+"' is unknown.")
if not all(self.varnames in stat_dict): if not all(items in stat_dict for items in self.varnames):
print("Keys in stat_dict:") print("Keys in stat_dict:")
print(stat_dict.keys()) print(stat_dict.keys())
...@@ -193,11 +194,14 @@ class norm_data: ...@@ -193,11 +194,14 @@ class norm_data:
print(self.varnames) print(self.varnames)
raise ValueError("Could not find all requested variables in statistics dictionary.") raise ValueError("Could not find all requested variables in statistics dictionary.")
for varname in varnames_uni: for varname in self.varnames:
for stat_name in knwon_norms[norm]: for stat_name in self.known_norms[norm]:
setattr(self,varname+stat_name,stat_dict[varname][0][stat_name]) setattr(self,varname+stat_name,stat_dict[varname][0][stat_name])
self.status_ok = True self.status_ok = True
for i in range(len(self.varnames)):
print(self.varnames[i])
print(getattr(self,self.varnames[i]+"min"))
def norm_var(self,data,varname,norm): def norm_var(self,data,varname,norm):
...@@ -226,9 +230,9 @@ class norm_data: ...@@ -226,9 +230,9 @@ class norm_data:
print(norm_avail) print(norm_avail)
raise ValueError("Passed normalization '"+norm+"' is unknown.") raise ValueError("Passed normalization '"+norm+"' is unknown.")
if norm = "minmax": if norm == "minmax":
return(data[...] * (getattr(self,varname+"max") - getattr(self,varname+"min")) + getattr(self,varname+"max")) return(data[...] * (getattr(self,varname+"max") - getattr(self,varname+"min")) + getattr(self,varname+"max"))
elif norm = "znorm": elif norm == "znorm":
return(data[...] * getattr(self,varname+"sigma")**2 + getattr(self,varname+"avg")) return(data[...] * getattr(self,varname+"sigma")**2 + getattr(self,varname+"avg"))
...@@ -237,7 +241,7 @@ def read_frames_and_save_tf_records(output_dir,input_dir,partition_name,vars_in, ...@@ -237,7 +241,7 @@ def read_frames_and_save_tf_records(output_dir,input_dir,partition_name,vars_in,
# Include vars_in for more flexible data handling (normalization and reshaping) # Include vars_in for more flexible data handling (normalization and reshaping)
# and optional keyword argument for kind of normalization # and optional keyword argument for kind of normalization
if norm in kwargs: if 'norm' in kwargs:
norm = kwargs.get("norm") norm = kwargs.get("norm")
else: else:
norm = "minmax" norm = "minmax"
...@@ -261,7 +265,7 @@ def read_frames_and_save_tf_records(output_dir,input_dir,partition_name,vars_in, ...@@ -261,7 +265,7 @@ def read_frames_and_save_tf_records(output_dir,input_dir,partition_name,vars_in,
# open statistics file and store the dictionary # open statistics file and store the dictionary
with open(os.path.join(input_dir,"statistics.json")) as js_file: with open(os.path.join(input_dir,"statistics.json")) as js_file:
norm_cls.check_and_set_norm(json.load(js_file),norm_name) norm_cls.check_and_set_norm(json.load(js_file),norm)
#if (norm == "minmax"): #if (norm == "minmax"):
#varmin, varmax = get_stat_allvars(data,"min",vars_in), get_stat_allvars(data,"max",vars_in) #varmin, varmax = get_stat_allvars(data,"min",vars_in), get_stat_allvars(data,"max",vars_in)
...@@ -293,12 +297,8 @@ def read_frames_and_save_tf_records(output_dir,input_dir,partition_name,vars_in, ...@@ -293,12 +297,8 @@ def read_frames_and_save_tf_records(output_dir,input_dir,partition_name,vars_in,
###Normalization should adopt the selected variables, here we used duplicated channel temperature variables ###Normalization should adopt the selected variables, here we used duplicated channel temperature variables
sequences = np.array(sequences) sequences = np.array(sequences)
### normalization ### normalization
# ML 2020/04/08:
# again rather inelegant/inefficient as...
# a) normalization should be cast in class definition (with initialization, setting of norm. approach including
# data retrieval and the normalization itself)
for i in range(nvars): for i in range(nvars):
sequences[:,:,:,:,i] = norm_cls.norm_var(sequences[:,:,:,:,i],vars_in[i],norm_name) sequences[:,:,:,:,i] = norm_cls.norm_var(sequences[:,:,:,:,i],vars_in[i],norm)
output_fname = 'sequence_{0}_to_{1}.tfrecords'.format(last_start_sequence_iter, sequence_iter - 1) output_fname = 'sequence_{0}_to_{1}.tfrecords'.format(last_start_sequence_iter, sequence_iter - 1)
output_fname = os.path.join(output_dir, output_fname) output_fname = os.path.join(output_dir, output_fname)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment