From a25e78d408f53f968fc7ccecce88a0c54a4476c6 Mon Sep 17 00:00:00 2001
From: michael <m.langguth@fz-juelich.de>
Date: Tue, 26 May 2020 10:25:35 +0200
Subject: [PATCH] Removal of several typos and other bugfixes.

---
 video_prediction/datasets/era5_dataset_v2.py | 28 ++++++++++----------
 1 file changed, 14 insertions(+), 14 deletions(-)

diff --git a/video_prediction/datasets/era5_dataset_v2.py b/video_prediction/datasets/era5_dataset_v2.py
index dec2b559..2eac367d 100644
--- a/video_prediction/datasets/era5_dataset_v2.py
+++ b/video_prediction/datasets/era5_dataset_v2.py
@@ -167,7 +167,8 @@ def save_tf_record(output_fname, sequences):
             
 class norm_data:
     
-    knwon_norms = {}
+    ### set known norms and the requested statistics (to be retrieved from statistics.json) here ###
+    known_norms = {}
     known_norms["minmax"] = ["min","max"]
     known_norms["znorm"]  = ["avg","sigma"]
     
@@ -184,8 +185,8 @@ class norm_data:
             for norm_avail in self.known_norms.keys():
                 print(norm_avail)
             raise ValueError("Passed normalization '"+norm+"' is unknown.")
-        
-        if not all(self.varnames in stat_dict):
+       
+        if not all(items in stat_dict for items in self.varnames):
             print("Keys in stat_dict:")
             print(stat_dict.keys())
             
@@ -193,11 +194,14 @@ class norm_data:
             print(self.varnames)
             raise ValueError("Could not find all requested variables in statistics dictionary.")   
 
-        for varname in varnames_uni:
-            for stat_name in knwon_norms[norm]:
+        for varname in self.varnames:
+            for stat_name in self.known_norms[norm]:
                 setattr(self,varname+stat_name,stat_dict[varname][0][stat_name])
                 
         self.status_ok = True
+        for i in range(len(self.varnames)):
+            print(self.varnames[i])
+            print(getattr(self,self.varnames[i]+"min"))
                 
     def norm_var(self,data,varname,norm):
         
@@ -226,9 +230,9 @@ class norm_data:
                 print(norm_avail)
             raise ValueError("Passed normalization '"+norm+"' is unknown.")
         
-        if norm = "minmax":
+        if norm == "minmax":
             return(data[...] * (getattr(self,varname+"max") - getattr(self,varname+"min")) + getattr(self,varname+"max"))
-        elif norm = "znorm":
+        elif norm == "znorm":
             return(data[...] * getattr(self,varname+"sigma")**2 + getattr(self,varname+"avg"))
         
 
@@ -237,7 +241,7 @@ def read_frames_and_save_tf_records(output_dir,input_dir,partition_name,vars_in,
     # Include vars_in for more flexible data handling (normalization and reshaping)
     # and optional keyword argument for kind of normalization
     
-    if norm in kwargs:
+    if 'norm' in kwargs:
         norm = kwargs.get("norm")
     else:
         norm = "minmax"
@@ -261,7 +265,7 @@ def read_frames_and_save_tf_records(output_dir,input_dir,partition_name,vars_in,
     
     # open statistics file and store the dictionary
     with open(os.path.join(input_dir,"statistics.json")) as js_file:
-        norm_cls.check_and_set_norm(json.load(js_file),norm_name)        
+        norm_cls.check_and_set_norm(json.load(js_file),norm)        
     
         #if (norm == "minmax"):
             #varmin, varmax = get_stat_allvars(data,"min",vars_in), get_stat_allvars(data,"max",vars_in)
@@ -293,12 +297,8 @@ def read_frames_and_save_tf_records(output_dir,input_dir,partition_name,vars_in,
             ###Normalization should adpot the selected variables, here we used duplicated channel temperature variables
             sequences = np.array(sequences)
             ### normalization
-            # ML 2020/04/08:
-            # again rather inelegant/inefficient as...
-            # a) normalization should be cast in class definition (with initialization, setting of norm. approach including 
-            #    data retrieval and the normalization itself
             for i in range(nvars):    
-                sequences[:,:,:,:,i] = norm_cls.norm_var(sequences[:,:,:,:,i],vars_in[i],norm_name)
+                sequences[:,:,:,:,i] = norm_cls.norm_var(sequences[:,:,:,:,i],vars_in[i],norm)
 
             output_fname = 'sequence_{0}_to_{1}.tfrecords'.format(last_start_sequence_iter, sequence_iter - 1)
             output_fname = os.path.join(output_dir, output_fname)
-- 
GitLab