Skip to content
Snippets Groups Projects
Commit 77e28763 authored by BING GONG's avatar BING GONG
Browse files

add ssim evaluation to main_visulatioz_postprocess

parent af430f46
Branches
Tags
No related merge requests found
Pipeline #68444 passed
...@@ -58,7 +58,7 @@ class Postprocess(TrainModel): ...@@ -58,7 +58,7 @@ class Postprocess(TrainModel):
self.input_dir_tfr = None self.input_dir_tfr = None
self.input_dir_pkl = None self.input_dir_pkl = None
# forecast products and evaluation metrics to be handled in postprocessing # forecast products and evaluation metrics to be handled in postprocessing
self.eval_metrics = ["mse", "psnr"] self.eval_metrics = ["mse", "psnr", "ssim"]
self.fcst_products = {"persistence": "pfcst", "model": "mfcst"} self.fcst_products = {"persistence": "pfcst", "model": "mfcst"}
# initialize dataset to track evaluation metrics and configure bootstrapping procedure # initialize dataset to track evaluation metrics and configure bootstrapping procedure
self.eval_metrics_ds = None self.eval_metrics_ds = None
...@@ -393,7 +393,7 @@ class Postprocess(TrainModel): ...@@ -393,7 +393,7 @@ class Postprocess(TrainModel):
self.stochastic_loss_all_batches = np.mean(np.array(self.stochastic_loss_all_batches), axis=0) self.stochastic_loss_all_batches = np.mean(np.array(self.stochastic_loss_all_batches), axis=0)
assert len(np.array(self.persistent_loss_all_batches).shape) == 1 assert len(np.array(self.persistent_loss_all_batches).shape) == 1
assert np.array(self.persistent_loss_all_batches).shape[0] == self.future_length assert np.array(self.persistent_loss_all_batches).shape[0] == self.future_length
print("Bug here:", np.array(self.stochastic_loss_all_batches).shape)
assert len(np.array(self.stochastic_loss_all_batches).shape) == 2 assert len(np.array(self.stochastic_loss_all_batches).shape) == 2
assert np.array(self.stochastic_loss_all_batches).shape[0] == self.num_stochastic_samples assert np.array(self.stochastic_loss_all_batches).shape[0] == self.num_stochastic_samples
...@@ -530,7 +530,7 @@ class Postprocess(TrainModel): ...@@ -530,7 +530,7 @@ class Postprocess(TrainModel):
# dictionary of implemented evaluation metrics # dictionary of implemented evaluation metrics
dims = ["lat", "lon"] dims = ["lat", "lon"]
known_eval_metrics = {"mse": Scores("mse", dims), "psnr": Scores("psnr", dims)} known_eval_metrics = {"mse": Scores("mse", dims), "psnr": Scores("psnr", dims),"ssim": Scores("ssim",dims)}
# generate list of functions that calculate requested evaluation metrics # generate list of functions that calculate requested evaluation metrics
if set(self.eval_metrics).issubset(known_eval_metrics): if set(self.eval_metrics).issubset(known_eval_metrics):
......
import tensorflow as tf import tensorflow as tf
#import lpips_tf #import lpips_tf
import numpy as np
import math import math
from skimage.metrics import structural_similarity as ssim from skimage.metrics import structural_similarity as ssim_ski
def mse(a, b): def mse(a, b):
return tf.reduce_mean(tf.squared_difference(a, b), [-3, -2, -1]) return tf.reduce_mean(tf.squared_difference(a, b), [-3, -2, -1])
...@@ -43,6 +42,6 @@ def ssim_images(image1, image2): ...@@ -43,6 +42,6 @@ def ssim_images(image1, image2):
:param image1 the reference images :param image1 the reference images
:param image2 the predicte images :param image2 the predicte images
""" """
ssim_pred = ssim(image1, image2, ssim_pred = ssim_ski(image1, image2,
data_range = image2.max() - image2.min()) data_range = image2.max() - image2.min())
return ssim_pred return ssim_pred
from typing import Union, Tuple, Dict, List
import numpy as np import numpy as np
import xarray as xr import xarray as xr
import pandas as pd
from typing import Union, List from typing import Union, List
from skimage.metrics import structural_similarity as ssim
try: try:
from tqdm import tqdm from tqdm import tqdm
l_tqdm = True l_tqdm = True
...@@ -120,18 +121,18 @@ class Scores: ...@@ -120,18 +121,18 @@ class Scores:
Class to calculate scores and skill scores. Class to calculate scores and skill scores.
""" """
known_scores = ["mse", "psnr"] known_scores = ["mse", "psnr","ssim"]
def __init__(self, score_name: str, dims: List[str]): def __init__(self, score_name: str, dims: List[str]):
""" """
Initialize score instance. Initialize score instance.
:param score_name: name of score taht is queried :param score_name: name of score that is queried
:param dims: list of dimension over which the score shall operate :param dims: list of dimension over which the score shall operate
:return: Score instance :return: Score instance
""" """
method = Scores.__init__.__name__ method = Scores.__init__.__name__
self.metrics_dict = {"mse": self.calc_mse_batch , "psnr": self.calc_psnr_batch} self.metrics_dict = {"mse": self.calc_mse_batch , "psnr": self.calc_psnr_batch, "ssim":self.calc_ssim_batch}
if set(self.metrics_dict.keys()) != set(Scores.known_scores): if set(self.metrics_dict.keys()) != set(Scores.known_scores):
raise ValueError("%{0}: Known scores must coincide with keys of metrics_dict.".format(method)) raise ValueError("%{0}: Known scores must coincide with keys of metrics_dict.".format(method))
self.score_name = self.set_score_name(score_name) self.score_name = self.set_score_name(score_name)
...@@ -193,10 +194,10 @@ class Scores: ...@@ -193,10 +194,10 @@ class Scores:
def calc_psnr_batch(self, data_fcst, data_ref, **kwargs): def calc_psnr_batch(self, data_fcst, data_ref, **kwargs):
""" """
Calculate mse of forecast data w.r.t. reference data Calculate psnr of forecast data w.r.t. reference data
:param data_fcst: forecasted data (xarray with dimensions [batch, lat, lon]) :param data_fcst: forecasted data (xarray with dimensions [batch, lat, lon])
:param data_ref: reference data (xarray with dimensions [batch, lat, lon]) :param data_ref: reference data (xarray with dimensions [batch, lat, lon])
:return: averaged mse for each batch example :return: averaged psnr for each batch example
""" """
method = Scores.calc_mse_batch.__name__ method = Scores.calc_mse_batch.__name__
...@@ -215,3 +216,16 @@ class Scores: ...@@ -215,3 +216,16 @@ class Scores:
return psnr return psnr
def calc_ssim_batch(self, data_fcast, data_ref, **kwargs):
"""
Calculate ssim ealuation metric of forecast data w.r.t reference data
:param data_fcst: forecasted data (xarray with dimensions [batch, lat, lon])
:param data_ref: reference data (xarray with dimensions [batch, lat, lon])
:return: averaged ssim for each batch example
"""
method = Scores.calc_ssim_batch.__name__
ssim_pred = ssim(data_ref, data_fcast,
data_range = data_fcast.max() - data_fcast.min())
return ssim_pred
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment