diff --git a/Jupyter_Notebooks/calc_climatolgical_mean.ipynb b/Jupyter_Notebooks/calc_climatolgical_mean.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..0094d2fe5617b0bb2af9241b21b5feefc3294afb
--- /dev/null
+++ b/Jupyter_Notebooks/calc_climatolgical_mean.ipynb
@@ -0,0 +1,313 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "id": "annoying-jamaica",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os, sys, time\n",
+    "import xarray as xr\n",
+    "import pandas as pd\n",
+    "import datetime as dt\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "id": "legislative-portugal",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "datadir = \"/p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/T2monthly\"\n",
+    "\n",
+    "datafile = os.path.join(datadir, \"t2m_1970_1999.nc\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "id": "through-cornell",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "with xr.open_dataset(datafile) as dfile:\n",
+    "    t2m_all = dfile[\"var167\"]\n",
+    "    coords = t2m_all.coords"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "id": "steady-implement",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ntimes = len(coords[\"time\"])\n",
+    "\n",
+    "t2m_all = t2m_all.chunk({\"time\": ntimes, \"lat\": 100, \"lon\": 100})"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "id": "universal-neutral",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/p/software/hdfml/stages/2020/software/Jupyter/2020.2.6-gcccoremkl-9.3.0-2020.2.254-Python-3.8.5/lib/python3.8/site-packages/xarray/core/indexing.py:1369: PerformanceWarning: Slicing is producing a large chunk. To accept the large\n",
+      "chunk and silence this warning, set the option\n",
+      "    >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):\n",
+      "    ...     array[indexer]\n",
+      "\n",
+      "To avoid creating the large chunks, set the option\n",
+      "    >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):\n",
+      "    ...     array[indexer]\n",
+      "  return self.array[key]\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Registering averaging took 1.08\n",
+      "Performing averaging took 1.08\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/p/software/hdfml/stages/2020/software/Jupyter/2020.2.6-gcccoremkl-9.3.0-2020.2.254-Python-3.8.5/lib/python3.8/site-packages/xarray/core/indexing.py:1369: PerformanceWarning: Slicing is producing a large chunk. To accept the large\n",
+      "chunk and silence this warning, set the option\n",
+      "    >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):\n",
+      "    ...     array[indexer]\n",
+      "\n",
+      "To avoid creating the large chunks, set the option\n",
+      "    >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):\n",
+      "    ...     array[indexer]\n",
+      "  return self.array[key]\n"
+     ]
+    }
+   ],
+   "source": [
+    "# define a function for the hourly averaging:\n",
+    "def hour_mean(x):\n",
+    "    return x.groupby('time.hour').mean('time')\n",
+    "\n",
+    "# Note: with dask-backed arrays, groupby/apply only registers the computation;\n",
+    "# the evaluation is deferred until compute() is called below, so both timers\n",
+    "# here measure the (lazy) graph construction only.\n",
+    "time0 = time.time()\n",
+    "t2m_hourly = t2m_all.groupby(\"time.month\").apply(hour_mean)\n",
+    "\n",
+    "print(\"Registering averaging took {0:.2f}\".format(time.time()-time0))\n",
+    "\n",
+    "print(\"Performing averaging took {0:.2f}\".format(time.time()-time0))\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "id": "signed-edmonton",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "t2m_hourly = t2m_hourly.compute()"
+   ]
+  },
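+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "optional-persist",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Optional sketch: persist the computed climatology for later reuse.\n",
+    "# The output file name is a hypothetical choice.\n",
+    "outfile = os.path.join(datadir, \"climatology_t2m_1970-1999_xr.nc\")\n",
+    "t2m_hourly.to_netcdf(outfile)"
+   ]
+  },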
+  {
+   "cell_type": "markdown",
+   "id": "sustainable-significance",
+   "metadata": {},
+   "source": [
+    "This works, but it takes about 3 minutes to process 30 years of data. <br>\n",
+    "However, the same operation is possible with CDO and takes only 36 s on Juwels. <br>\n",
+    "The following two shell commands (after loading CDO 1.9.8 and ecCodes 2.18.0) are:\n",
+    "```\n",
+    "clim_files=($(for year in {1991..2020}; do echo \"${year}_t2m.grb\"; done))\n",
+    "cdo -t ecmwf -f nc ensavg ${clim_files[@]} multiyears_1991-2020.nc\n",
+    "```\n",
+    "In the following, we check the correctness of the data by computing the difference between the data from a CDO-generated file and the data produced above. As an example, we choose the mean temperature in January at 12 UTC."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 42,
+   "id": "rental-suite",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "datafile_cdo = os.path.join(datadir, \"climatology_t2m_1970-1999.nc\")\n",
+    "\n",
+    "with xr.open_dataset(datafile_cdo) as dfile:\n",
+    "    t2m_hourly_cdo = dfile[\"T2M\"]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 43,
+   "id": "better-adventure",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "<xarray.DataArray ()>\n",
+      "array(0.00097656, dtype=float32)\n",
+      "Coordinates:\n",
+      "    hour     int64 12\n",
+      "    month    int64 1\n",
+      "    time     datetime64[ns] 1979-01-01T12:00:00\n"
+     ]
+    }
+   ],
+   "source": [
+    "import numpy as np\n",
+    "test1 = t2m_hourly.sel(month=1, hour=12)\n",
+    "test2 = t2m_hourly_cdo.sel(time=\"1979-01-01 12:00\")\n",
+    "\n",
+    "diff = np.abs(test1 - test2)\n",
+    "\n",
+    "print(np.max(diff))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "weird-sociology",
+   "metadata": {},
+   "source": [
+    "Thus, the maximum difference is of order $\\mathcal{O}(10^{-3})$ K, which is negligible for our application."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "running-monday",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.5"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/video_prediction_tools/data_preprocess/calc_climatology.py b/video_prediction_tools/data_preprocess/calc_climatology.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bd3f54735da18e4a2b03dffc001c9364ff96395
--- /dev/null
+++ b/video_prediction_tools/data_preprocess/calc_climatology.py
@@ -0,0 +1,104 @@
+
+__email__ = "b.gong@fz-juelich.de"
+__author__ = "Yan Ji, Bing Gong"
+__date__ = "2021-05-26"
+"""
+Use the monthly average values to calculate the climatological values from the data sources downloaded from ECMWF: /p/fastdata/slmet/slmet111/met_data/ecmwf/era5/grib/monthly
+"""
+
+import os
+import time
+import glob
+import json
+
+
+class Calc_climatology(object):
+    def __init__(self, input_dir="/p/fastdata/slmet/slmet111/met_data/ecmwf/era5/grib/monthly",
+                 output_clim=None, region: list = [10, 20], json_path=None):
+        self.input_dir = input_dir
+        self.output_clim = output_clim
+        self.region = region
+        self.lats = None
+        self.lons = None
+        self.json_path = json_path
+
+    @staticmethod
+    def calc_avg_per_year(grib_fl: str = None):
+        """
+        Placeholder for calculating the yearly average from a monthly-mean grib file.
+        :param grib_fl: the relative path of the monthly average values
+        :return: None
+        """
+        return None
+
+    def cal_avg_all_files(self):
+        """
+        Average over time for all the grib files
+        :return: None
+        """
+
+        method = Calc_climatology.cal_avg_all_files.__name__
+
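+        # The climatology is computed via two CDO calls:
+        #   1. "cdo mergetime" concatenates all yearly grib files along the time axis.
+        #   2. "cdo timavg" averages the merged file over the time dimension.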
+        multiyears_path = os.path.join(self.output_clim, "multiyears.grb")
+        climatology_path = os.path.join(self.output_clim, "climatology.grb")
+        grib_files = glob.glob(os.path.join(self.input_dir, "*t2m.grb"))
+        if grib_files:
+            print("{0}: Found {1} monthly grib files".format(method, len(grib_files)))
+            # merge all files into one grib file; the file list must be expanded
+            # to a space-separated string for the shell command
+            os.system("cdo mergetime {0} {1}".format(" ".join(grib_files), multiyears_path))
+            # average over time
+            os.system("cdo timavg {0} {1}".format(multiyears_path, climatology_path))
+        else:
+            raise FileNotFoundError("{0}: No monthly grib files found in the input directory {1}"
+                                    .format(method, self.input_dir))
+
+        return None
+
+    def get_lat_lon_from_json(self):
+        """
+        Get lons and lats from the metadata json file
+        :return: list of lats and lons
+        """
+        method = Calc_climatology.get_lat_lon_from_json.__name__
+
+        if not os.path.exists(self.json_path):
+            raise FileNotFoundError("{0}: The json file {1} does not exist".format(method, self.json_path))
+
+        with open(self.json_path) as fl:
+            meta_data = json.load(fl)
+
+        if "coordinates" not in meta_data:
+            raise KeyError("{0}: 'coordinates' is not one of the keys of metadata.json".format(method))
+        else:
+            meta_coords = meta_data["coordinates"]
+            self.lats = meta_coords["lat"]
+            self.lons = meta_coords["lon"]
+        return self.lats, self.lons
+
+    def get_region_climate(self):
+        """
+        Get the climatological values for the selected region
+        :return: None
+        """
+        pass
+
+    def __call__(self, *args, **kwargs):
+
+        # note: output_clim must point to an existing directory before the
+        # climatology can be computed
+        if not os.path.exists(os.path.join(self.output_clim, "climatology.grb")):
+            self.cal_avg_all_files()
+        self.get_lat_lon_from_json()
+        self.get_region_climate()
+
+
+if __name__ == '__main__':
+    exp = Calc_climatology(json_path="/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/era5-Y2007-2019M01to12-92x56-3840N0000E-2t_tcc_t_850/metadata.json")
+    exp()
diff --git a/video_prediction_tools/data_preprocess/prepare_era5_data.py b/video_prediction_tools/data_preprocess/prepare_era5_data.py
index 4f9e4c8246e36a83b8353cd00366b3e7ed83761c..fbb71bdb8b9df163d38a2bfbd06a8e8ecea4588d 100644
--- a/video_prediction_tools/data_preprocess/prepare_era5_data.py
+++ b/video_prediction_tools/data_preprocess/prepare_era5_data.py
@@ -60,7 +60,7 @@ class ERA5DataExtraction(object):
             temp_path = os.path.join(self.target_dir, self.year, month)
             os.makedirs(temp_path, exist_ok=True)
 
-            for var,value in self.varslist_surface.items():
+            for var, value in self.varslist_surface.items():
                 # surface variables
                 infile = os.path.join(self.src_dir, self.year, month, self.year+month+day+hour+'_sf.grb')
                 outfile_sf = os.path.join(self.target_dir, self.year, month, self.year+month+day+hour+'_'+var+'.nc')
diff --git a/video_prediction_tools/main_scripts/main_visualize_postprocess.py b/video_prediction_tools/main_scripts/main_visualize_postprocess.py
index e7d982c12acaddb9352299240181152ef880e522..552c2cbbf30f96898f0e566884c49022e919643a 100644
--- a/video_prediction_tools/main_scripts/main_visualize_postprocess.py
+++ b/video_prediction_tools/main_scripts/main_visualize_postprocess.py
@@ -18,9 +18,8 @@ import datetime as dt
 import json
 from typing import Union, List
 # own modules
-from general_utils import get_era5_varatts
 from normalization import Norm_data
-from general_utils import check_dir
+from general_utils import get_era5_varatts, check_dir
 from metadata import MetaData as MetaData
 from main_scripts.main_train_models import *
 from data_preprocess.preprocess_data_step2 import *
@@ -30,9 +29,10 @@ from postprocess_plotting import plot_avg_eval_metrics, plot_cond_quantile, crea
 
 class Postprocess(TrainModel):
-    def __init__(self, results_dir=None, checkpoint=None, mode="test", batch_size=None, num_stochastic_samples=1,
-                 stochastic_plot_id=0, gpu_mem_frac=None, seed=None, channel=0, args=None, run_mode="deterministic",
-                 eval_metrics=None):
+    def __init__(self, results_dir: str = None, checkpoint: str = None, mode: str = "test", batch_size: int = None,
+                 num_stochastic_samples: int = 1, stochastic_plot_id: int = 0, gpu_mem_frac: float = None,
+                 seed: int = None, channel: int = 0, args=None, run_mode: str = "deterministic",
+                 eval_metrics: List = ("mse", "psnr", "ssim")):
         """
         Initialization of the class instance for postprocessing (generation of forecasts from trained model +
         basic evaluation).
@@ -443,7 +443,7 @@ class Postprocess(TrainModel):
         self.stochastic_loss_all_batches = np.mean(np.array(self.stochastic_loss_all_batches), axis=0)
         assert len(np.array(self.persistent_loss_all_batches).shape) == 1
         assert np.array(self.persistent_loss_all_batches).shape[0] == self.future_length
-        print("Bug here:", np.array(self.stochastic_loss_all_batches).shape)
+        assert len(np.array(self.stochastic_loss_all_batches).shape) == 2
         assert np.array(self.stochastic_loss_all_batches).shape[0] == self.num_stochastic_samples
 
@@ -597,7 +597,7 @@ class Postprocess(TrainModel):
         # dictionary of implemented evaluation metrics
         dims = ["lat", "lon"]
-        known_eval_metrics = {"mse": Scores("mse", dims), "psnr": Scores("psnr", dims)}
+        known_eval_metrics = {"mse": Scores("mse", dims), "psnr": Scores("psnr", dims), "ssim": Scores("ssim", dims)}
 
         # generate list of functions that calculate requested evaluation metrics
         if set(self.eval_metrics).issubset(known_eval_metrics.keys()):
diff --git a/video_prediction_tools/model_modules/video_prediction/metrics.py b/video_prediction_tools/model_modules/video_prediction/metrics.py
index 61c13e91b74f1f53f01ab03fa65467aafb5844b2..f88fa38b9771a29c1ed5ec88ed9db417f270e31e 100644
--- a/video_prediction_tools/model_modules/video_prediction/metrics.py
+++ b/video_prediction_tools/model_modules/video_prediction/metrics.py
@@ -1,7 +1,8 @@
 import tensorflow as tf
 #import lpips_tf
-import numpy as np
 import math
+from skimage.measure import compare_ssim as ssim_ski
+
 
 def mse(a, b):
     return tf.reduce_mean(tf.squared_difference(a, b), [-3, -2, -1])
@@ -23,6 +24,7 @@ def psnr_imgs(img1, img2, pixel_max=1.):
 def mse_imgs(image1,image2):
     mse = ((image1 - image2)**2).mean(axis=None)
     return mse
+
 # def lpips(input0, input1):
 #     if input0.shape[-1].value == 1:
 #         input0 = tf.tile(input0, [1] * (input0.shape.ndims - 1) + [3])
@@ -32,9 +34,14 @@ def mse_imgs(image1,image2):
 #         distance = lpips_tf.lpips(input0, input1)
 #         return -distance
 
-def ssim_images(image1,image2):
+def ssim_images(image1, image2):
     """
-
+    Calculate the SSIM between two images with the scikit-image implementation.
+    References: https://cvnote.ddlee.cc/2019/09/12/psnr-ssim-python
+                https://scikit-image.org/docs/dev/auto_examples/transform/plot_ssim.html
+    :param image1: the reference image
+    :param image2: the predicted image
     """
-    pass
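+    # data_range is the dynamic range of the image values; here it is derived
+    # from the predicted image, assuming both images share the same value range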
+    ssim_pred = ssim_ski(image1, image2,
+                         data_range=image2.max() - image2.min())
+    return ssim_pred
diff --git a/video_prediction_tools/postprocess/statistical_evaluation.py b/video_prediction_tools/postprocess/statistical_evaluation.py
index 6e5da00dd06a1fc10c953dd76fb5053de66a4e34..9c2e885de6d85f698fd7ce3471b123228c7e5828 100644
--- a/video_prediction_tools/postprocess/statistical_evaluation.py
+++ b/video_prediction_tools/postprocess/statistical_evaluation.py
@@ -8,8 +8,9 @@ __date__ = "2021-05-xx"
 
 import numpy as np
 import xarray as xr
-import pandas as pd
 from typing import Union, List
+from skimage.measure import compare_ssim as ssim
+
 try:
     from tqdm import tqdm
     l_tqdm = True
@@ -134,7 +135,7 @@ def perform_block_bootstrap_metric(metric: da_or_ds, dim_name: str, block_length
                                    seed: int = 42):
     """
     Performs block bootstrapping on metric along given dimension (e.g. along time dimension)
-    :param metric: DataArray or dataset of metric that should be bootstrapped 
+    :param metric: DataArray or dataset of metric that should be bootstrapped
     :param dim_name: name of the dimension on which division into blocks is applied
     :param block_length: length of block (index-based)
     :param nboots_block: number of bootstrapping steps to be performed
@@ -200,18 +201,18 @@ class Scores:
     Class to calculate scores and skill scores.
     """
 
-    known_scores = ["mse", "psnr"]
+    known_scores = ["mse", "psnr", "ssim"]
 
     def __init__(self, score_name: str, dims: List[str]):
         """
         Initialize score instance.
-        :param score_name: name of score taht is queried
-        :param dims: list of dimension over which the score shall operate
-        :return: Score instance
+        :param score_name: name of score that is queried
+        :param dims: list of dimensions over which the score shall operate
+        :return: Score instance
         """
         method = Scores.__init__.__name__
-        self.metrics_dict = {"mse": self.calc_mse_batch , "psnr": self.calc_psnr_batch}
+        self.metrics_dict = {"mse": self.calc_mse_batch, "psnr": self.calc_psnr_batch, "ssim": self.calc_ssim_batch}
         if set(self.metrics_dict.keys()) != set(Scores.known_scores):
             raise ValueError("%{0}: Known scores must coincide with keys of metrics_dict.".format(method))
         self.score_name = self.set_score_name(score_name)
@@ -273,10 +274,10 @@ class Scores:
 
     def calc_psnr_batch(self, data_fcst, data_ref, **kwargs):
         """
-        Calculate mse of forecast data w.r.t. reference data
+        Calculate psnr of forecast data w.r.t. reference data
         :param data_fcst: forecasted data (xarray with dimensions [batch, lat, lon])
         :param data_ref: reference data (xarray with dimensions [batch, lat, lon])
-        :return: averaged mse for each batch example
+        :return: averaged psnr for each batch example
         """
         method = Scores.calc_mse_batch.__name__
@@ -295,3 +296,17 @@ class Scores:
 
         return psnr
 
+    def calc_ssim_batch(self, data_fcst, data_ref, **kwargs):
+        """
+        Calculate the ssim evaluation metric of forecast data w.r.t. reference data
+        :param data_fcst: forecasted data (xarray with dimensions [batch, lat, lon])
+        :param data_ref: reference data (xarray with dimensions [batch, lat, lon])
+        :return: ssim of the first forecast/reference sample (not yet averaged over the batch)
+        """
+        method = Scores.calc_ssim_batch.__name__
+        # note: only the first sample of the batch is evaluated here; the
+        # [0, 0, :, :] indexing implies 4-D input data
+        ssim_pred = ssim(data_ref[0, 0, :, :], data_fcst[0, 0, :, :])
+        return ssim_pred
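Note: `calc_ssim_batch` above evaluates only the first sample of the batch, whereas the other `calc_*_batch` methods return one score per batch example. A batch-averaged variant could look like the following sketch (a hypothetical helper, assuming numpy-convertible 4-D inputs of shape `[batch, time, lat, lon]` as implied by the `[0, 0, :, :]` indexing):

```python
import numpy as np
from skimage.measure import compare_ssim as ssim


def calc_ssim_batch_mean(data_fcst, data_ref):
    """Sketch: average SSIM over all samples and lead times in a batch."""
    fcst, ref = np.asarray(data_fcst), np.asarray(data_ref)
    scores = [
        # data_range is taken per sample from the reference field
        ssim(ref[i, t, :, :], fcst[i, t, :, :],
             data_range=float(ref[i, t].max() - ref[i, t].min()))
        for i in range(ref.shape[0])
        for t in range(ref.shape[1])
    ]
    return float(np.mean(scores))
```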