diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py b/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py index 99b134314d47f1086e4473f84c6c8450cb29d58f..5aa43ed7555750134ceb8ad568cd90b8aa85ca14 100644 --- a/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py +++ b/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py @@ -142,16 +142,17 @@ class ERA5Dataset(BaseDataset): def save_stats(self, variables: list = None, output_dir: str = None, **kwargs): output_file = os.path.join(output_dir, "statistics.json") - stats = {} stats_keys = list(kwargs.keys()) stats_values = list(kwargs.values()) - stats_var = {} #for each variable + stats = {} print("stats_values",stats_values) for i, var in enumerate(variables): + stats_var = {} #for each variable for j, key in enumerate(stats_keys): print("var: {}, key:{}, value: {}".format(var, key,stats_values[j][i])) - stats_var.update({var:{key: float(stats_values[j][i])}}) - stats.update(stats_var) + stats_var.update({key: float(stats_values[j][i])}) + print("stats_var",stats_var) + stats.update({var:stats_var}) print("stats",stats) #save to output directory json.dumps(stats)