From c759d326a4896d32f3ef842ad1df0a3a3006e8a3 Mon Sep 17 00:00:00 2001
From: masak1112 <gongbing1112@gmail.com>
Date: Wed, 13 Jul 2022 15:39:19 +0200
Subject: [PATCH] fix bugs to save statsitics information to output directory
 during training process

---
 .../video_prediction/datasets/era5_dataset.py            | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py b/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py
index 99b13431..5aa43ed7 100644
--- a/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py
+++ b/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py
@@ -142,16 +142,17 @@ class ERA5Dataset(BaseDataset):
     def save_stats(self, variables: list = None, output_dir: str = None, **kwargs):
 
         output_file = os.path.join(output_dir, "statistics.json")
-        stats = {}
         stats_keys = list(kwargs.keys())
         stats_values = list(kwargs.values())
-        stats_var = {} #for each variable
+        stats = {}
         print("stats_values",stats_values)
         for i, var in enumerate(variables):
+            stats_var = {} #for each variable
             for j, key in enumerate(stats_keys):
                 print("var: {}, key:{}, value: {}".format(var, key,stats_values[j][i]))
-                stats_var.update({var:{key: float(stats_values[j][i])}})
-            stats.update(stats_var)
+                stats_var.update({key: float(stats_values[j][i])})
+                print("stats_var",stats_var)
+            stats.update({var:stats_var})
         print("stats",stats)
         #save to output directory
         json.dumps(stats)
-- 
GitLab