diff --git a/video_prediction_tools/main_scripts/main_visualize_postprocess.py b/video_prediction_tools/main_scripts/main_visualize_postprocess.py
index 015d4b67988f98dd027ab6bf3843408062dcdb5d..90f0e42b13d4457668827877268c00696106bc5c 100644
--- a/video_prediction_tools/main_scripts/main_visualize_postprocess.py
+++ b/video_prediction_tools/main_scripts/main_visualize_postprocess.py
@@ -58,7 +58,7 @@ class Postprocess(TrainModel):
         self.input_dir_tfr = None
         self.input_dir_pkl = None
         # forecast products and evaluation metrics to be handled in postprocessing
-        self.eval_metrics = ["mse", "psnr"]
+        self.eval_metrics = ["mse", "psnr", "ssim"]
         self.fcst_products = {"persistence": "pfcst", "model": "mfcst"}
         # initialize dataset to track evaluation metrics and configure bootstrapping procedure
         self.eval_metrics_ds = None
@@ -393,7 +393,7 @@ class Postprocess(TrainModel):
         self.stochastic_loss_all_batches = np.mean(np.array(self.stochastic_loss_all_batches), axis=0)
         assert len(np.array(self.persistent_loss_all_batches).shape) == 1
         assert np.array(self.persistent_loss_all_batches).shape[0] == self.future_length
-        print("Bug here:", np.array(self.stochastic_loss_all_batches).shape)
+        assert len(np.array(self.stochastic_loss_all_batches).shape) == 2
         assert np.array(self.stochastic_loss_all_batches).shape[0] == self.num_stochastic_samples


@@ -530,7 +530,7 @@

         # dictionary of implemented evaluation metrics
         dims = ["lat", "lon"]
-        known_eval_metrics = {"mse": Scores("mse", dims), "psnr": Scores("psnr", dims)}
+        known_eval_metrics = {"mse": Scores("mse", dims), "psnr": Scores("psnr", dims), "ssim": Scores("ssim", dims)}

         # generate list of functions that calculate requested evaluation metrics
         if set(self.eval_metrics).issubset(known_eval_metrics):
diff --git a/video_prediction_tools/model_modules/video_prediction/metrics.py b/video_prediction_tools/model_modules/video_prediction/metrics.py
index c807c093b61e6054968f17f8d320dc6e5fec9a50..ef79a72dadef3651ec83607be7cd7e586d8cedef 100644
--- a/video_prediction_tools/model_modules/video_prediction/metrics.py
+++ b/video_prediction_tools/model_modules/video_prediction/metrics.py
@@ -1,8 +1,7 @@
 import tensorflow as tf
 #import lpips_tf
-import numpy as np
 import math
-from skimage.metrics import structural_similarity as ssim
+from skimage.metrics import structural_similarity as ssim_ski

 def mse(a, b):
     return tf.reduce_mean(tf.squared_difference(a, b), [-3, -2, -1])
@@ -43,6 +42,6 @@
     :param image1 the reference images
     :param image2 the predicte images
     """
-    ssim_pred = ssim(image1, image2,
+    ssim_pred = ssim_ski(image1, image2,
                      data_range = image2.max() - image2.min())
     return ssim_pred
diff --git a/video_prediction_tools/postprocess/statistical_evaluation.py b/video_prediction_tools/postprocess/statistical_evaluation.py
index d91b12cea79305f09fb7f8eaa3cbd00421da3d65..1fe4586428d5a639e4624f02b1fca3117faa292c 100644
--- a/video_prediction_tools/postprocess/statistical_evaluation.py
+++ b/video_prediction_tools/postprocess/statistical_evaluation.py
@@ -1,8 +1,9 @@
-from typing import Union, Tuple, Dict, List
+
 import numpy as np
 import xarray as xr
-import pandas as pd
 from typing import Union, List
+from skimage.metrics import structural_similarity as ssim
+
 try:
     from tqdm import tqdm
     l_tqdm = True
@@ -54,7 +55,7 @@
                                         seed: int = 42):
     """
     Performs block bootstrapping on metric along given dimension (e.g. along time dimension)
-    :param metric: DataArray or dataset of metric that should be bootstrapped 
+    :param metric: DataArray or dataset of metric that should be bootstrapped
     :param dim_name: name of the dimension on which division into blocks is applied
     :param block_length: length of block (index-based)
     :param nboots_block: number of bootstrapping steps to be performed
@@ -120,18 +121,18 @@ class Scores:
     Class to calculate scores and skill scores.
     """

-    known_scores = ["mse", "psnr"]
+    known_scores = ["mse", "psnr", "ssim"]

     def __init__(self, score_name: str, dims: List[str]):
         """
         Initialize score instance.
-        :param score_name: name of score taht is queried
-        :param dims: list of dimension over which the score shall operate
-        :return: Score instance
+        :param score_name: name of score that is queried
+        :param dims: list of dimensions over which the score shall operate
+        :return: Score instance
         """
         method = Scores.__init__.__name__

-        self.metrics_dict = {"mse": self.calc_mse_batch , "psnr": self.calc_psnr_batch}
+        self.metrics_dict = {"mse": self.calc_mse_batch, "psnr": self.calc_psnr_batch, "ssim": self.calc_ssim_batch}
         if set(self.metrics_dict.keys()) != set(Scores.known_scores):
             raise ValueError("%{0}: Known scores must coincide with keys of metrics_dict.".format(method))
         self.score_name = self.set_score_name(score_name)
@@ -193,10 +194,10 @@

     def calc_psnr_batch(self, data_fcst, data_ref, **kwargs):
         """
-        Calculate mse of forecast data w.r.t. reference data
+        Calculate psnr of forecast data w.r.t. reference data
         :param data_fcst: forecasted data (xarray with dimensions [batch, lat, lon])
         :param data_ref: reference data (xarray with dimensions [batch, lat, lon])
-        :return: averaged mse for each batch example
+        :return: averaged psnr for each batch example
         """
         method = Scores.calc_mse_batch.__name__

@@ -215,3 +216,17 @@

         return psnr

+    def calc_ssim_batch(self, data_fcst, data_ref, **kwargs):
+        """
+        Calculate the SSIM evaluation metric of forecast data w.r.t. reference data
+        :param data_fcst: forecasted data (xarray with dimensions [batch, lat, lon])
+        :param data_ref: reference data (xarray with dimensions [batch, lat, lon])
+        :return: mean SSIM over all batch examples
+        """
+        method = Scores.calc_ssim_batch.__name__
+
+        # structural_similarity treats the whole [batch, lat, lon] stack as one 3D image
+        ssim_pred = ssim(data_ref, data_fcst,
+                         data_range=data_fcst.max() - data_fcst.min())
+
+        return ssim_pred
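
Note on the new calc_ssim_batch: skimage's structural_similarity treats the full [batch, lat, lon] array as a single 3D image and returns one scalar, whereas calc_mse_batch and calc_psnr_batch reduce over ["lat", "lon"] and keep the batch dimension. If a per-example SSIM is wanted to match the other metrics, a minimal sketch could look like the following (the helper name ssim_per_example and the plain-numpy inputs are illustrative assumptions, not part of this patch; fields must be at least 7x7 for the default window size):

    import numpy as np
    from skimage.metrics import structural_similarity as ssim

    def ssim_per_example(data_fcst: np.ndarray, data_ref: np.ndarray) -> np.ndarray:
        # one SSIM value per batch example for [batch, lat, lon] arrays,
        # using a common data_range so all examples share the same scale
        drange = float(data_fcst.max() - data_fcst.min())
        return np.array([ssim(ref2d, fcst2d, data_range=drange)
                         for ref2d, fcst2d in zip(data_ref, data_fcst)])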