diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py b/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py
index 99b134314d47f1086e4473f84c6c8450cb29d58f..5aa43ed7555750134ceb8ad568cd90b8aa85ca14 100644
--- a/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py
+++ b/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py
@@ -142,16 +142,17 @@ class ERA5Dataset(BaseDataset):
     def save_stats(self, variables: list = None, output_dir: str = None, **kwargs):
 
         output_file = os.path.join(output_dir, "statistics.json")
-        stats = {}
         stats_keys = list(kwargs.keys())
         stats_values = list(kwargs.values())
-        stats_var = {} #for each variable
+        stats = {}
         print("stats_values",stats_values)
         for i, var in enumerate(variables):
+            stats_var = {} #for each variable
             for j, key in enumerate(stats_keys):
                 print("var: {}, key:{}, value: {}".format(var, key,stats_values[j][i]))
-                stats_var.update({var:{key: float(stats_values[j][i])}})
-            stats.update(stats_var)
+                stats_var.update({key: float(stats_values[j][i])})
+                print("stats_var",stats_var)
+            stats.update({var:stats_var})
         print("stats",stats)
         #save to output directory
         json.dumps(stats)