diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..3a0299e121bebf0b7df05492be71ad0a877c0703 --- /dev/null +++ b/README.md @@ -0,0 +1,8 @@ +# Benchmark dataset for statistical downscaling with deep neural networks + +This repository aims to provide a benchmark dataset for statistical downscaling of meteorological fields with deep neural networks. +The work is pursued within the scope of the [MAELSTROM project](https://www.maelstrom-eurohpc.eu/) which aims to develop new machine learning applications for weather and climate under the coordination of ECMWF. <br> +The benchmark dataset is based on two reanalysis datasets that are well established and quality-controlled in meteorology: The ERA5-reanalysis data serves as the input dataset, whereas the COSMO-REA6 dataset provides the target data for the downscaling. With a grid spacing of 0.275° for the input data compared to 0.0625° for the target data, the downscaling factor of the benchmark dataset equals 4. Furthermore, different specific downscaling tasks adapted from the literature are provided. These pertain to the following meteorological quantities: +- 2m temperature +- 10m wind +- solar irradiance diff --git a/downscaling_ap5/config/config_ds_tier2.json b/downscaling_ap5/config/config_ds_tier2.json index e47fc5f6efb8d97d5a9b72ca54d930b92312373c..6c63b7da6e0045c59410740606c60c05ec75d997 100644 --- a/downscaling_ap5/config/config_ds_tier2.json +++ b/downscaling_ap5/config/config_ds_tier2.json @@ -1,4 +1,5 @@ { + "num_files": 33, "norm_dims": ["time", "rlat", "rlon"], "batch_size": 32, "var_tar2in": "hsurf_tar", diff --git a/downscaling_ap5/config/config_ds_tier2_wind.json b/downscaling_ap5/config/config_ds_tier2_wind.json index 9591a95ef1502d85659e42f285a8ff5b2bce4ef7..5130555bd702ebeaaa206662c4455fd289f6bd27 100644 --- a/downscaling_ap5/config/config_ds_tier2_wind.json +++ b/downscaling_ap5/config/config_ds_tier2_wind.json @@ -2,5 +2,6 @@ "norm_dims": ["time", "rlat", "rlon"], "batch_size": 32, "var_tar2in": "hsurf_tar", - "predictands": ["u_10m_tar", "v_10m_tar", "hsurf_tar"] + "predictands": ["u_10m_tar", "v_10m_tar", "hsurf_tar"], + "recon_loss": "mae_vec" } diff --git a/downscaling_ap5/config/config_wgan.json b/downscaling_ap5/config/config_wgan.json index c9492c9eedd662e1023e4869ab9c1acf1d56a3c9..ab9a0958798f823a51dd7ec37d0e1582d9680285 100644 --- a/downscaling_ap5/config/config_wgan.json +++ b/downscaling_ap5/config/config_wgan.json @@ -1,11 +1,8 @@ { "nepochs": 60, "d_steps": 5, - "z_branch": true, - "lr_gen": 5e-05, - "lr_critic": 1e-06, - "lr_gen_end": 5e-06, "lr_decay": true, - "named_targets": false - + "named_targets": false, + "hparams_generator": {"lr": 5e-05, "lr_end": 5e-06, "z_branch": true, "activation": "swish", "l_avgpool": true}, + "hparams_critic": {"lr": 1e-06, "activation": "swish"} } diff --git a/downscaling_ap5/config/config_wgan_test.json b/downscaling_ap5/config/config_wgan_test.json index ec40059eb701b6a8076702b6b29d5cbe15622c6e..9c195b852f183159c71c51323ef0f46d63c7b504 100644 --- a/downscaling_ap5/config/config_wgan_test.json +++ b/downscaling_ap5/config/config_wgan_test.json @@ -1,11 +1,8 @@ { "nepochs": 5, "d_steps": 5, - "z_branch": true, - "lr_gen": 5e-05, - "lr_critic": 1e-06, - "lr_gen_end": 5e-06, "lr_decay": true, - "named_targets": false - + "named_targets": false, + "hparams_generator": {"lr": 5e-05, "lr_end": 5e-06, "z_branch": true, "activation": "swish", "l_avgpool": true}, + "hparams_critic": {"lr": 1e-06, "activation": "swish"} } diff --git 
a/downscaling_ap5/env_setup/modules_jsc.sh b/downscaling_ap5/env_setup/modules_jsc.sh index 15cbca28e674b16c87d5cc6c397ee8c282ce9a0b..88b43745487082603d8b53453a97e411b23318f5 100755 --- a/downscaling_ap5/env_setup/modules_jsc.sh +++ b/downscaling_ap5/env_setup/modules_jsc.sh @@ -20,9 +20,11 @@ if [[ 0 == 0 ]]; then # Restoring from model collection currently throws MPI-se ml GCC/11.2.0 ml ParaStationMPI/5.5.0-1 ml mpi4py/3.1.3 + ml tqdm/4.62.3 ml CDO/2.0.2 ml NCO/5.0.3 ml netcdf4-python/1.5.7-serial + ml scikit-image/0.18.3 ml SciPy-bundle/2021.10 ml xarray/0.20.1 ml dask/2021.9.1 diff --git a/downscaling_ap5/grid_des/crea6_reg_grid b/downscaling_ap5/grid_des/crea6_reg_grid new file mode 100644 index 0000000000000000000000000000000000000000..2c1dbc9a71d9727a7746cdd3df80fc8aad35e082 --- /dev/null +++ b/downscaling_ap5/grid_des/crea6_reg_grid @@ -0,0 +1,18 @@ +# +# Regular CREA6 grid for downscaling task +# +gridtype = lonlat +gridsize = 93312 +xsize = 432 +ysize = 216 +xname = lon +xlongname = "longitude" +xunits = "degrees" +yname = lat +ylongname = "latitude" +yunits = "degrees" +xfirst = -1.25 +xinc = 0.0625 +yfirst = 42.3125 +yinc = 0.0625 + diff --git a/downscaling_ap5/handle_data/handle_data_class.py b/downscaling_ap5/handle_data/handle_data_class.py index 4849de9bf506ca5e3b46d9e892d1ef38b56d765e..03b5030d5e160659330de129a2af026026619841 100644 --- a/downscaling_ap5/handle_data/handle_data_class.py +++ b/downscaling_ap5/handle_data/handle_data_class.py @@ -5,7 +5,7 @@ __author__ = "Michael Langguth" __email__ = "m.langguth@fz-juelich.de" __date__ = "2022-01-20" -__update__ = "2023-04-17" +__update__ = "2023-08-18" import os, glob from typing import List @@ -27,7 +27,7 @@ try: except: from multiprocessing.pool import ThreadPool from all_normalizations import ZScore -from other_utils import to_list, find_closest_divisor, free_mem +from other_utils import to_list, find_closest_divisor class HandleDataClass(object): @@ -321,7 +321,10 @@ class HandleDataClass(object): data_iter = data_iter.repeat() # clean-up to free some memory + # free_mem([da, da_in, da_tar, varnames_tar]) del da + del da_in + del da_tar gc.collect() return data_iter @@ -383,6 +386,11 @@ def get_dataset_filename(datadir: str, dataset_name: str, subset: str, laugmente if subset == "train": fname_suffix = f"{fname_suffix}*" if laugmented: raise ValueError("No augmented dataset available for Tier-2.") + elif dataset_name == "atmorep": + fname_suffix = f"{fname_suffix}_{dataset_name}_{subset}" + if subset == "train": + fname_suffix = f"{fname_suffix}*" + if laugmented: raise ValueError("No augmented dataset available for AtmoRep.") else: raise ValueError(f"Unknown dataset '{dataset_name}' passed.") @@ -431,7 +439,7 @@ class StreamMonthlyNetCDF(object): self.predictor_list = selected_predictors self.predictand_list = selected_predictands self.n_predictands, self.n_predictors = len(self.predictand_list), len(self.predictor_list) - self.all_vars = self.predictor_list + self.predictand_list + self.all_vars = self.predictor_list + self.predictand_list # ordering important to ensure that predictors come first! 
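The comment above stresses that the predictors must come first in all_vars. A minimal, illustrative sketch (not part of the patch, with made-up variable names) of why this matters: variables in an xarray.Dataset carry no guaranteed order, so the "variables" axis is re-selected with all_vars to pin the predictors to the leading channels before the stacked array is split into inputs and targets.

import numpy as np
import xarray as xr

# hypothetical toy dataset with one predictor ("*_in") and one predictand ("*_tar")
ds = xr.Dataset({"t_2m_tar": ("time", np.zeros(4)), "t2m_in": ("time", np.ones(4))})

all_vars = ["t2m_in", "t_2m_tar"]                            # predictors first, predictands last
da = ds.to_array("variables").sel({"variables": all_vars})
print(da["variables"].values)                                # -> ['t2m_in' 't_2m_tar'], i.e. predictors lead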
self.ds_all = xr.open_mfdataset(list(self.file_list), decode_cf=False, cache=False) # , parallel=True) self.var_tar2in = var_tar2in if self.var_tar2in is not None: @@ -463,7 +471,7 @@ class StreamMonthlyNetCDF(object): return self.nsamples def getitems(self, indices): - da_now = self.data_now.isel({self.sample_dim: indices}).to_array("variables") + da_now = self.data_now.isel({self.sample_dim: indices}).to_array("variables").sel({"variables": self.all_vars}) if self.var_tar2in is not None: # NOTE: * The order of the following operation must be the same as in make_tf_dataset_allmem # * The following operation order must concatenate var_tar2in by da_in to ensure @@ -600,7 +608,7 @@ class StreamMonthlyNetCDF(object): if all(stat_list): selected_vars = var_list else: - miss_inds = [i for i, x in enumerate(stat_list) if x] + miss_inds = [i for i, x in enumerate(stat_list) if not x] miss_vars = [var_list[i] for i in miss_inds] raise ValueError(f"Could not find the following variables in the dataset: {*miss_vars,}") @@ -660,8 +668,6 @@ class StreamMonthlyNetCDF(object): data_now = xr.concat([data_now, ds_add], dim=self.sample_dim) print(f"Appending data with {add_samples:d} samples took {timer() - t1:.2f}s" + f"(total #samples: {data_now.dims[self.sample_dim]})") - # free memory - free_mem([ds_add, add_samples, istart]) self.data_loaded[il] = data_now # timing @@ -670,8 +676,6 @@ class StreamMonthlyNetCDF(object): self.ds_proc_size += data_now.nbytes print(f"Dataset #{set_ind:d} ({il+1:d}/2) reading time: {t_read:.2f}s.") self.iload_next = il + 1 - # free memory - free_mem([nsamples, t_read, data_now]) return il diff --git a/downscaling_ap5/main_scripts/main_download_era5.py b/downscaling_ap5/main_scripts/main_download_era5.py new file mode 100644 index 0000000000000000000000000000000000000000..ffb04ac9d73ebe11ffbb96319f6483d3f90a7db3 --- /dev/null +++ b/downscaling_ap5/main_scripts/main_download_era5.py @@ -0,0 +1,75 @@ +# SPDX-FileCopyrightText: 2022 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC) +# +# SPDX-License-Identifier: MIT + +""" +Script to download ERA5 data from the CDS API. 
+""" + +__author__ = "Michael Langguth" +__email__ = "m.langguth@fz-juelich.de" +__date__ = "2023-11-22" +__update__ = "2023-08-22" + +# import modules +import os, sys +import json as js +import logging +import argparse +from download_era5_data import ERA5_Data_Loader + +# get logger +logger = logging.getLogger(os.path.basename(__file__).rstrip(".py")) +logger.setLevel(logging.INFO) +formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s: %(message)s') + +def main(parser_args): + + data_dir = parser_args.data_dir + + # read configuration files for model and dataset + with parser_args.data_req_file as fdreq: + req_dict = js.load(fdreq) + + + # create output directory + if not os.path.exists(data_dir): + os.makedirs(data_dir) + + # create logger handlers + logfile = os.path.join(data_dir, f"download_era5_{parser_args.exp_name}.log") + if os.path.isfile(logfile): os.remove(logfile) + fh = logging.FileHandler(logfile) + ch = logging.StreamHandler(sys.stdout) + ch.setLevel(logging.DEBUG) + fh.setLevel(logging.INFO) + + fh.setFormatter(formatter) + ch.setFormatter(formatter) + + logger.addHandler(fh), logger.addHandler(ch) + + # create data loader instance + data_loader = ERA5_Data_Loader(parser_args.nworkers) + + # download data + _ = data_loader(req_dict, data_dir, parser_args.start, parser_args.end, parser_args.format) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data_directory", "-data_dir", dest="data_dir", type=str, required=True, + help="Directory where test dataset (netCDF-file) is stored.") + parser.add_argument("--data_request_file", "-data_req_file", dest="data_req_file", type=argparse.FileType("r"), required=True, + help="File containing data request information for the CDS API.") + parser.add_argument("--year_start", "-start", dest="start", type=int, default=1995, + help="Start year of ERA5-data request.") + parser.add_argument("--year_end", "-end", dest="end", type=int, default=2019, + help="End year of ERA5-data request.") + parser.add_argument("--nworkers", "-nw", dest="nowrkers", type=int, default=4, + help="Number of workers to download ERA5 data.") + parser.add_argument("--format", "-format", dest="format", type=str, default="netcdf", + help="Format of downloaded data.") + + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/downscaling_ap5/main_scripts/main_postprocess.py b/downscaling_ap5/main_scripts/main_postprocess.py index 0e60fbd7fbf19a05c6e0c88cb65bc4b66cceca43..e9735171e6fc7050467e8691d7cb49f39d845fe7 100644 --- a/downscaling_ap5/main_scripts/main_postprocess.py +++ b/downscaling_ap5/main_scripts/main_postprocess.py @@ -9,23 +9,28 @@ Driver-script to perform inference on trained downscaling models. 
__author__ = "Michael Langguth" __email__ = "m.langguth@fz-juelich.de" __date__ = "2022-12-08" -__update__ = "2022-12-08" +__update__ = "2023-08-21" import os, sys, glob import logging import argparse from timeit import default_timer as timer import json as js +from datetime import datetime as dt +#import datetime as dt +import gc import numpy as np import xarray as xr import tensorflow.keras as keras import matplotlib as mpl +import cartopy.crs as ccrs from handle_data_unet import * from handle_data_class import HandleDataClass, get_dataset_filename from all_normalizations import ZScore from statistical_evaluation import Scores -from postprocess import get_model_info, run_evaluation_time, run_evaluation_spatial -from datetime import datetime as dt +from postprocess import get_model_info, run_evaluation_time, run_evaluation_spatial, run_feature_importance +from model_utils import convert_to_xarray +#from other_utils import free_mem # get logger logger = logging.getLogger(os.path.basename(__file__).rstrip(".py")) @@ -33,6 +38,7 @@ logger.setLevel(logging.INFO) formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s: %(message)s') + def main(parser_args): t0 = timer() @@ -59,7 +65,8 @@ def main(parser_args): logger.addHandler(fh), logger.addHandler(ch) - logger.info(f"Start postprocessing at {dt.now().strftime('%Y-%m-%d %H:%M:%S')}") + #logger.info(f"Start postprocessing at {dt.now().strftime('%Y-%m-%d %H:%M:%S')}") + logger.info(f"Start postprocessing at...") # read configuration files md_config_pattern, ds_config_pattern = f"config_{model_type}.json", f"config_ds_{parser_args.dataset}.json" @@ -105,41 +112,46 @@ def main(parser_args): logger.info(f"Variable {tar_varname} serves as ground truth data.") with xr.open_dataset(fdata_test) as ds_test: - ground_truth = ds_test[tar_varname].astype("float32", copy=False) + ground_truth = ds_test[tar_varname].astype("float32", copy=True) ds_test = norm.normalize(ds_test) # prepare training and validation data logger.info(f"Start preparing test dataset...") t0_preproc = timer() - da_test = HandleDataClass.reshape_ds(ds_test.astype("float32", copy=False)) - tfds_test = HandleDataClass.make_tf_dataset_allmem(da_test.astype("float32", copy=True), ds_dict["batch_size"], - ds_dict["predictands"], lshuffle=False, var_tar2in=ds_dict["var_tar2in"], - named_targets=named_targets, lrepeat=False, drop_remainder=False) + da_test = HandleDataClass.reshape_ds(ds_test).astype("float32", copy=True) + + # clean-up to reduce memory footprint + del ds_test + gc.collect() + #free_mem([ds_test]) + + tfds_opts = {"batch_size": ds_dict["batch_size"], "predictands": ds_dict["predictands"], "predictors": ds_dict.get("predictors", None), + "lshuffle": False, "var_tar2in": ds_dict["var_tar2in"], "named_targets": named_targets, "lrepeat": False, "drop_remainder": False} - # perform normalization - da_test_in, da_test_tar = HandleDataClass.split_in_tar(da_test, predictands=ds_dict["predictands"]) + tfds_test = HandleDataClass.make_tf_dataset_allmem(da_test, **tfds_opts) + + predictors = ds_dict.get("predictors", None) + if predictors is None: + predictors = [var for var in list(da_test["variables"].values) if var.endswith("_in")] + if ds_dict.get("var_tar2in", False): predictors.append(ds_dict["var_tar2in"]) # start inference logger.info(f"Preparation of test dataset finished after {timer() - t0_preproc:.2f}s. 
" + "Start inference on trained model...") t0_train = timer() - y_pred_trans = trained_model.predict(tfds_test, verbose=2) + y_pred = trained_model.predict(tfds_test, verbose=2) logger.info(f"Inference on test dataset finished. Start denormalization of output data...") - # get coordinates and dimensions from target data - slice_dict = {"variables": 0} if hparams_dict["z_branch"] else {} - coords = da_test_tar.isel(slice_dict).squeeze().coords - dims = da_test_tar.isel(slice_dict).squeeze().dims - if hparams_dict["z_branch"]: - # slice data to get first channel only - if isinstance(y_pred_trans, list): y_pred_trans = y_pred_trans[0] - y_pred = xr.DataArray(y_pred_trans[..., 0].squeeze(), coords=coords, dims=dims) - else: - # no slicing required - y_pred = xr.DataArray(y_pred_trans.squeeze(), coords=coords, dims=dims) - # perform denormalization - y_pred = norm.denormalize(y_pred.squeeze(), varname=tar_varname) + + # clean-up to reduce memory footprint + del tfds_test + gc.collect() + #free_mem([tfds_test]) + + # convert to xarray + y_pred = convert_to_xarray(y_pred, norm, tar_varname, da_test.sel({"variables": tar_varname}).squeeze().coords, + da_test.sel({"variables": tar_varname}).squeeze().dims, hparams_dict["z_branch"]) # write inference data to netCDf logger.info(f"Write inference data to netCDF-file '{ncfile_out}'") @@ -155,7 +167,7 @@ def main(parser_args): logger.info("Start temporal evaluation...") t0_tplot = timer() - _ = run_evaluation_time(score_engine, "rmse", "K", plt_dir, value_range=(0., 3.), model_type=model_type) + rmse_all = run_evaluation_time(score_engine, "rmse", "K", plt_dir, value_range=(0., 3.), model_type=model_type) _ = run_evaluation_time(score_engine, "bias", "K", plt_dir, value_range=(-1., 1.), ref_line=0., model_type=model_type) _ = run_evaluation_time(score_engine, "grad_amplitude", "1", plt_dir, value_range=(0.7, 1.1), @@ -163,19 +175,43 @@ def main(parser_args): logger.info(f"Temporal evalutaion finished in {timer() - t0_tplot:.2f}s.") + # run feature importance analysis for RMSE + logger.info("Start feature importance analysis...") + t0_fi = timer() + + rmse_ref = rmse_all.mean().values + + _ = run_feature_importance(da_test, predictors, tar_varname, trained_model, norm, "rmse", rmse_ref, + tfds_opts, plt_dir, patch_size=(6, 6), variable_dim="variables") + + logger.info(f"Feature importance analysis finished in {timer() - t0_fi:.2f}s.") + + # clean-up to reduce memory footprint + del da_test + gc.collect() + #free_mem([da_test]) + # instantiate score engine with retained spatial dimensions score_engine = Scores(y_pred, ground_truth, []) + # ad-hoc adaption to projection basaed on norm_dims + if "rlat" in ds_dict["norm_dims"]: + proj=ccrs.RotatedPole(pole_longitude=-162.0, pole_latitude=39.25) + else: + proj=ccrs.PlateCarree() + logger.info("Start spatial evaluation...") lvl_rmse = np.arange(0., 3.1, 0.2) cmap_rmse = mpl.cm.afmhot_r(np.linspace(0., 1., len(lvl_rmse))) - _ = run_evaluation_spatial(score_engine, "rmse", os.path.join(plt_dir, "rmse_spatial"), cmap=cmap_rmse, - levels=lvl_rmse) + _ = run_evaluation_spatial(score_engine, "rmse", os.path.join(plt_dir, "rmse_spatial"), + dims=ds_dict["norm_dims"][1::], cmap=cmap_rmse, levels=lvl_rmse, + projection=proj) lvl_bias = np.arange(-2., 2.1, 0.1) cmap_bias = mpl.cm.seismic(np.linspace(0., 1., len(lvl_bias))) - _ = run_evaluation_spatial(score_engine, "bias", os.path.join(plt_dir, "bias_spatial"), cmap=cmap_bias, - levels=lvl_bias) + _ = run_evaluation_spatial(score_engine, "bias", os.path.join(plt_dir, 
"bias_spatial"), + dims=ds_dict["norm_dims"][1::], cmap=cmap_bias, levels=lvl_bias, + projection=proj) logger.info(f"Spatial evalutaion finished in {timer() - t0_tplot:.2f}s.") diff --git a/downscaling_ap5/main_scripts/main_train.py b/downscaling_ap5/main_scripts/main_train.py index 66970190c0e00b826422ada7c7ae0ec09ff5b4a4..04521802e9f79dc6b39da66ae46bd6e87d79a7f8 100644 --- a/downscaling_ap5/main_scripts/main_train.py +++ b/downscaling_ap5/main_scripts/main_train.py @@ -9,7 +9,7 @@ Driver-script to train downscaling models. __author__ = "Michael Langguth" __email__ = "m.langguth@fz-juelich.de" __date__ = "2022-10-06" -__update__ = "2023-04-17" +__update__ = "2023-08-18" import os import argparse @@ -24,16 +24,13 @@ from tensorflow.keras.utils import plot_model from all_normalizations import ZScore from model_utils import ModelEngine, TimeHistory, handle_opt_utils, get_loss_from_history from handle_data_class import HandleDataClass, get_dataset_filename -from other_utils import free_mem, print_gpu_usage, print_cpu_usage, copy_filelist -from benchmark_utils import BenchmarkCSV, get_training_time_dict +from other_utils import print_gpu_usage, print_cpu_usage, copy_filelist +from benchmark_utils import get_training_time_dict # Open issues: # * d_steps must be parsed with hparams_dict as model is uninstantiated at this point and thus no default parameters # are available -# * flag named_targets must be set to False in hparams_dict for WGAN to work with U-Net -# * ensure that dataset defaults are set (such as predictands for WGAN) -# * replacement of manual benchmarking by JUBE-benchmarking to avoid duplications def main(parser_args): # start timing @@ -67,9 +64,6 @@ def main(parser_args): data_norm, write_norm = None, True norm_dims = ds_dict["norm_dims"] - # initialize benchmarking object - bm_obj = BenchmarkCSV(os.path.join(os.getcwd(), f"benchmark_training_{parser_args.model}.csv")) - # get model instance and set-up batch size model_instance = ModelEngine(parser_args.model) # Note: bs_train is introduced to allow substepping in the training loop, e.g. 
for WGAN where n optimization steps @@ -90,7 +84,7 @@ def main(parser_args): # the dataset if "*" in fname_or_patt_train: ds_obj, tfds_train = HandleDataClass.make_tf_dataset_dyn(datadir, fname_or_patt_train, bs_train, nepochs, - 30, ds_dict["predictands"], + ds_dict["num_files"], ds_dict["predictands"], predictors=ds_dict.get("predictors", None), var_tar2in=ds_dict["var_tar2in"], named_targets=named_targets, @@ -100,7 +94,11 @@ def main(parser_args): tfds_train_size = ds_obj.dataset_size else: ds_train = xr.open_dataset(fname_or_patt_train) - da_train = HandleDataClass.reshape_ds(ds_train.astype("float32", copy=False)) + da_train = HandleDataClass.reshape_ds(ds_train).astype("float32", copy=True) + + # free up some memory + del ds_train + gc.collect() if not data_norm: # data_norm must be freshly instantiated (triggering later parameter retrieval) @@ -111,9 +109,14 @@ def main(parser_args): predictors=ds_dict.get("predictors", None), var_tar2in=ds_dict["var_tar2in"], named_targets=named_targets) + nsamples, shape_in = da_train.shape[0], tfds_train.element_spec[0].shape[1:].as_list() tfds_train_size = da_train.nbytes + # clean up to save some memory + del da_train + gc.collect() + if write_norm: data_norm.save_norm_to_file(os.path.join(model_savedir, "norm.json")) @@ -125,15 +128,17 @@ def main(parser_args): fdata_val = get_dataset_filename(datadir, dataset, "val", ds_dict.get("laugmented", False)) with xr.open_dataset(fdata_val) as ds_val: ds_val = data_norm.normalize(ds_val) - da_val = HandleDataClass.reshape_ds(ds_val) + da_val = HandleDataClass.reshape_ds(ds_val).astype("float32", copy=True) - tfds_val = HandleDataClass.make_tf_dataset_allmem(da_val.astype("float32", copy=True), ds_dict["batch_size"], + tfds_val = HandleDataClass.make_tf_dataset_allmem(da_val, ds_dict["batch_size"], ds_dict["predictands"], predictors=ds_dict.get("predictors", None), lshuffle=True, var_tar2in=ds_dict["var_tar2in"], named_targets=named_targets) # clean up to save some memory - free_mem([ds_val, da_val]) + del ds_val + del da_val + gc.collect() tval_load = timer() - t0_val print(f"Validation data preparation time: {tval_load:.2f}s.") @@ -177,21 +182,21 @@ def main(parser_args): ttrain_load = sum(ds_obj.reading_times) + tval_load print(f"Data loading time: {ttrain_load:.2f}s.") print(f"Average throughput: {ds_obj.ds_proc_size / 1.e+06 / training_times['Total training time']:.3f} MB/s") - benchmark_dict = {**{"data loading time": ttrain_load}, **training_times} - print(f"Model '{parser_args.exp_name}' training time: {training_times['Total training time']:.2f} s. 
" + - f"Save model to '{model_savedir}'") # save trained model t0_save = timer() + model_savedir_last = os.path.join(model_savedir, f"{parser_args.exp_name}_last") os.makedirs(model_savedir, exist_ok=True) - model.save(filepath=model_savedir) + model.save(filepath=model_savedir_last) - if callable(getattr(model, "plot_model", False)): - model.plot_model(model_savedir, show_shapes=True) - else: + #if callable(getattr(model, "plot_model", False)): + try: + model.plot_model(model_savedir, show_shapes=True) # , show_layer_actiavtions=True) + #else: + except: plot_model(model, os.path.join(model_savedir, f"plot_{parser_args.exp_name}.png"), - show_shapes=True) + show_shapes=True) #, show_layer_activations=True) # final timing tend = timer() @@ -203,36 +208,6 @@ def main(parser_args): print_gpu_usage("Final GPU memory: ") print_cpu_usage("Final CPU memory: ") - # populate benchmark dictionary - benchmark_dict.update({"saving model time": saving_time, "total runtime": tot_run_time}) - benchmark_dict.update({"job id": job_id, "#nodes": 1, "#cpus": len(os.sched_getaffinity(0)), "#gpus": 1, - "#mpi tasks": 1, "node id": None, "max. gpu power": None, "gpu energy consumption": None}) - try: - benchmark_dict["final training loss"] = get_loss_from_history(history, "loss") - except KeyError: - benchmark_dict["final training loss"] = get_loss_from_history(history, "recon_loss") - try: - benchmark_dict["final validation loss"] = get_loss_from_history(history, "val_loss") - except KeyError: - benchmark_dict["final validation loss"] = get_loss_from_history(history, "val_recon_loss") - # ... and save CSV-file with tracked data on disk - bm_obj.populate_csv_from_dict(benchmark_dict) - - js_file = os.path.join(model_savedir, "benchmark_training_static.json") - if not os.path.isfile(js_file): - func_model_info = getattr(model, "get_model_info", None) - if callable(func_model_info): - model_info = func_model_info() - else: - model_info = {} - stat_info = {"static_model_info": model_info, - "data_info": {"training data size": tfds_train_size, "validation data size": da_val.nbytes, - "nsamples": nsamples, "shape_samples": shape_in, - "batch_size": ds_dict["batch_size"]}} - - with open(js_file, "w") as jsf: - js.dump(stat_info, jsf) - print("Finished job at {0}".format(dt.strftime(dt.now(), "%Y-%m-%d %H:%M:%S"))) diff --git a/downscaling_ap5/models/custom_losses.py b/downscaling_ap5/models/custom_losses.py new file mode 100644 index 0000000000000000000000000000000000000000..46e5e537a2a40d4d61657c90da2581fe651d6f4b --- /dev/null +++ b/downscaling_ap5/models/custom_losses.py @@ -0,0 +1,160 @@ +# SPDX-FileCopyrightText: 2022 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC) +# +# SPDX-License-Identifier: MIT + +""" +Some custmoized losses (e.g. on vector quantities) +""" + +__author__ = "Michael Langguth" +__email__ = "m.langguth@fz-juelich.de" +__date__ = "2023-06-16" +__update__ = "2023-06-16" + +# import module +import inspect +import tensorflow as tf + +def fix_channels(n_channels): + """ + Decorator to fix number of channels in loss functions. 
+ """ + def decorator(func): + def wrapper(y_true, y_pred, **func_kwargs): + return func(y_true, y_pred, n_channels, **fixed_kwargs, **func_kwargs) + return wrapper + return decorator + +def get_custom_loss(loss_name, **kwargs): + """ + Loss factory including some customized losses and all available Keras losses + :param loss_name: name of the loss function + :return: the respective loss function + """ + known_losses = ["mse_channels", "mae_channels", "mae_vec", "mse_vec", "critic", "critic_generator"] + \ + [loss_cls[0] for loss_cls in inspect.getmembers(tf.keras.losses, inspect.isclass)] + + loss_name = loss_name.lower() + + n_channels = kwargs.get("n_channels", None) + + if loss_name == "mse_channels": + assert n_channels > 0, f"n_channels must be a number larger than zero, but is {n_channels}." + loss_fn = fix_channels(**kwargs)(mse_channels) + elif loss_name == "mae_channels": + assert n_channels > 0, f"n_channels must be a number larger than zero, but is {n_channels}." + loss_fn = fix_channels(**kwargs)(mae_channels) + elif loss_name == "mae_vec": + assert n_channels > 0, f"n_channels must be a number larger than zero, but is {n_channels}." + loss_fn = fix_channels(**kwargs)(mae_vec) + elif loss_name == "mse_vec": + assert n_channels > 0, f"n_channels must be a number larger than zero, but is {n_channels}." + loss_fn = fix_channels(**kwargs)(mse_vec) + elif loss_name == "critic": + loss_fn = critic_loss + elif loss_name == "critic_generator": + loss_fn = critic_gen_loss + else: + try: + loss_fn = getattr(tf.keras.losses, loss_name)(**kwargs) + except AttributeError: + raise ValueError(f"{loss_name} is not a valid loss function. Choose one of the following: {known_losses}") + + return loss_fn + + +def mae_channels(x, x_hat, n_channels: int = None, channels_last: bool = True, avg_channels: bool = False): + rloss = 0. + if channels_last: + # get MAE for all output heads + for i in range(n_channels): + rloss += tf.reduce_mean(tf.abs(tf.squeeze(x_hat[..., i]) - x[..., i])) + else: + for i in range(n_channels): + rloss += tf.reduce_mean(tf.abs(tf.squeeze(x_hat[i, ...]) - x[i, ...])) + + if avg_channels: + rloss /= n_channels + + return rloss + +def mse_channels(x, x_hat, n_channels, channels_last: bool = True, avg_channels: bool = False): + rloss = 0. + if channels_last: + # get MSE for all output heads + for i in range(n_channels): + rloss += tf.reduce_mean(tf.square(tf.squeeze(x_hat[..., i]) - x[..., i])) + else: + for i in range(n_channels): + rloss += tf.reduce_mean(tf.square(tf.squeeze(x_hat[i, ...]) - x[i, ...])) + + if avg_channels: + rloss /= n_channels + + return rloss + +def mae_vec(x, x_hat, n_channels, channels_last: bool = True, avg_channels: bool = False, nd_vec: int = None): + + if nd_vec is None: + nd_vec = n_channels + + rloss = 0. + if channels_last: + vec_ind = -1 + diff = tf.squeeze(x_hat[..., 0:nd_vec]) - x[..., 0:nd_vec] + else: + vec_ind = 1 + diff = tf.squeeze(x_hat[:,0:nd_vec, ...]) - x[:,0:nd_vec, ...] + + rloss = tf.reduce_mean(tf.norm(diff, axis=vec_ind)) + #rloss = tf.reduce_mean(tf.math.reduce_euclidean_norm(diff, axis=vec_ind)) + + if n_channels > nd_vec: + if channels_last: + rloss += mae_channels(x[..., nd_vec::], x_hat[..., nd_vec::], n_channels - nd_vec, True, avg_channels) + else: + rloss += mae_channels(x[:, nd_vec::, ...], x_hat[:, nd_vec::, ...], n_channels - nd_vec, False, avg_channels) + + return rloss + +def mse_vec(x, x_hat, n_channels, channels_last: bool = True, avg_channels: bool = False, nd_vec: int = None): + + if nd_vec is None: + nd_vec = n_channels + + rloss = 0. 
+ + if channels_last: + vec_ind = -1 + diff = tf.squeeze(x_hat[..., 0:nd_vec]) - x[..., 0:nd_vec] + else: + vec_ind = 1 + diff = tf.squeeze(x_hat[:,0:nd_vec, ...]) - x[:,0:nd_vec, ...] + + rloss = tf.reduce_mean(tf.square(tf.norm(diff, axis=vec_ind))) + + if n_channels > nd_vec: + if channels_last: + rloss += mse_channels(x[..., nd_vec::], x_hat[..., nd_vec::], n_channels - nd_vec, True, avg_channels) + else: + rloss += mse_channels(x[:, nd_vec::, ...], x_hat[:, nd_vec::, ...], n_channels - nd_vec, False, avg_channels) + + return rloss + +def critic_loss(critic_real, critic_gen): + """ + The critic is optimized to maximize the difference between the generated and the real data max(real - gen). + This is equivalent to minimizing the negative of this difference, i.e. min(gen - real) = max(real - gen) + :param critic_real: critic on the real data + :param critic_gen: critic on the generated data + :return c_loss: loss to optimize the critic + """ + c_loss = tf.reduce_mean(critic_gen - critic_real) + + return c_loss + + +def critic_gen_loss(critic_gen): + cg_loss = -tf.reduce_mean(critic_gen) + + return cg_loss \ No newline at end of file diff --git a/downscaling_ap5/models/model_utils.py b/downscaling_ap5/models/model_utils.py index 5011e9cdd2eaeab3d6bdf965ed80d6c69b86d325..0ee63045c1ed817ff3bd927642ea14b5c897562f 100644 --- a/downscaling_ap5/models/model_utils.py +++ b/downscaling_ap5/models/model_utils.py @@ -9,10 +9,11 @@ Some auxiliary methods to create Keras models. __author__ = "Michael Langguth" __email__ = "m.langguth@fz-juelich.de" __date__ = "2022-05-26" -__update__ = "2023-03-10" +__update__ = "2023-10-12" # import modules from timeit import default_timer as timer +import xarray as xr import tensorflow.keras as keras from unet_model import sha_unet, UNET from wgan_model import WGAN, critic_model @@ -137,3 +138,28 @@ def handle_opt_utils(model: keras.Model, opt_funcname: str): return opt_dict +def convert_to_xarray(mout_np, norm, varname, coords, dims, z_branch=False): + """ + Converts numpy-array of model output to xarray.DataArray and performs denormalization. + :param mout_np: numpy-array of model output + :param norm: normalization object + :param varname: name of variable + :param coords: coordinates of target data + :param dims: dimensions of target data + :param z_branch: flag for z-branch + :return: xarray.DataArray of model output with denormalized data + """ + if z_branch: + # slice data to get first channel only + if isinstance(mout_np, list): mout_np = mout_np[0] + mout_xr = xr.DataArray(mout_np[..., 0].squeeze(), coords=coords, dims=dims, name=varname) + else: + # no slicing required + mout_xr = xr.DataArray(mout_np.squeeze(), coords=coords, dims=dims, name=varname) + + # perform denormalization + mout_xr = norm.denormalize(mout_xr, varname=varname) + + return mout_xr + + diff --git a/downscaling_ap5/models/unet_model.py b/downscaling_ap5/models/unet_model.py index 60ee0ffdc30929bd822cd352fe5f747e8c002655..0465a89f3fd17357dbf92b6d276c9347fbc4cf76 100644 --- a/downscaling_ap5/models/unet_model.py +++ b/downscaling_ap5/models/unet_model.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: 2022 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC) +# SPDX-FileCopyrightText: 2023 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC) # # SPDX-License-Identifier: MIT @@ -9,28 +9,29 @@ Methods to set-up U-net models incl. its building blocks. 
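A short numeric sketch (illustrative only, not part of the patch, with made-up critic scores) of the Wasserstein-type losses defined in custom_losses.py above: the critic loss is the mean of critic_gen - critic_real, and the generator loss is the negated mean critic score on the generated samples.

import tensorflow as tf

critic_real = tf.constant([0.8, 0.6])               # critic scores on real samples (made up)
critic_gen = tf.constant([0.2, 0.1])                # critic scores on generated samples (made up)

c_loss = tf.reduce_mean(critic_gen - critic_real)   # ((0.2-0.8)+(0.1-0.6))/2 = -0.55; minimizing this maximizes real - gen
cg_loss = -tf.reduce_mean(critic_gen)               # -(0.2+0.1)/2 = -0.15; minimizing this pushes critic scores of generated samples up
print(float(c_loss), float(cg_loss))                # -> approx. -0.55 and -0.15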
__author__ = "Michael Langguth" __email__ = "m.langguth@fz-juelich.de" __date__ = "2021-XX-XX" -__update__ = "2023-05-06" +__update__ = "2023-11-22" # import modules import os from typing import List +import inspect import numpy as np import tensorflow as tf import tensorflow.keras as keras # all the layers used for U-net from tensorflow.keras.layers import (Concatenate, Conv2D, Conv2DTranspose, Input, MaxPool2D, BatchNormalization, - Activation) + Activation, AveragePooling2D) from tensorflow.keras.models import Model from tensorflow.keras.callbacks import LearningRateScheduler, ModelCheckpoint, EarlyStopping -import advanced_activations +from advanced_activations import advanced_activation from other_utils import to_list # building blocks for Unet def conv_block(inputs, num_filters: int, kernel: tuple = (3, 3), strides: tuple = (1, 1), padding: str = "same", - activation: str = "relu", activation_args=None, kernel_init: str = "he_normal", + activation: str = "swish", activation_args={}, kernel_init: str = "he_normal", l_batch_normalization: bool = True): """ A convolutional layer with optional batch normalization @@ -51,101 +52,139 @@ def conv_block(inputs, num_filters: int, kernel: tuple = (3, 3), strides: tuple try: x = Activation(activation)(x) except ValueError: - ac_layer = advanced_activations(activation, *activation_args) + ac_layer = advanced_activation(activation, *activation_args) x = ac_layer(x) return x -def conv_block_n(inputs, num_filters: int, n: int = 2, kernel: tuple = (3, 3), strides: tuple = (1, 1), - padding: str = "same", activation: str = "relu", activation_args=None, - kernel_init: str = "he_normal", l_batch_normalization: bool = True): +def conv_block_n(inputs, num_filters: int, n: int = 2, **kwargs): """ Sequential application of two convolutional layers (using conv_block). :param inputs: the input data with dimensions nx, ny and nc :param num_filters: number of filters (output channel dimension) :param n: number of convolutional blocks - :param kernel: tuple for convolution kernel size - :param strides: tuple for stride of convolution - :param padding: technique for padding (e.g. "same" or "valid") - :param activation: activation fuction for neurons (e.g. "relu") - :param activation_args: arguments for activation function given that advanced layers are applied - :param kernel_init: initialization technique (e.g. "he_normal" or "glorot_uniform") - :param l_batch_normalization: flag if batch normalization should be applied + :param kwargs: keyword arguments for conv_block """ - x = conv_block(inputs, num_filters, kernel, strides, padding, activation, activation_args, - kernel_init, l_batch_normalization) + x = conv_block(inputs, num_filters, **kwargs) for _ in np.arange(n - 1): - x = conv_block(x, num_filters, kernel, strides, padding, activation, activation_args, - kernel_init, l_batch_normalization) + x = conv_block(x, num_filters, **kwargs) return x -def encoder_block(inputs, num_filters, kernel_maxpool: tuple = (2, 2), l_large: bool = True): +def encoder_block(inputs, num_filters, l_large: bool = True, kernel_pool: tuple = (2, 2), l_avgpool: bool = False, **kwargs): """ One complete encoder-block used in U-net. 
:param inputs: input to encoder block :param num_filters: number of filters/channel to be used in convolutional blocks - :param kernel_maxpool: kernel used in max-pooling :param l_large: flag for large encoder block (two consecutive convolutional blocks) + :param kernel_maxpool: kernel used in max-pooling + :param l_avgpool: flag if average pooling is used instead of max pooling + :param kwargs: keyword arguments for conv_block """ if l_large: - x = conv_block_n(inputs, num_filters, n=2) + x = conv_block_n(inputs, num_filters, n=2, **kwargs) else: - x = conv_block(inputs, num_filters) + x = conv_block(inputs, num_filters, **kwargs) - p = MaxPool2D(kernel_maxpool)(x) + if l_avgpool: + p = AveragePooling2D(kernel_pool)(x) + else: + p = MaxPool2D(kernel_pool)(x) return x, p +def subpixel_block(inputs, num_filters, kernel: tuple = (3,3), upscale_fac: int = 2, + padding: str = "same", activation: str = "swish", activation_args: dict = {}, + kernel_init: str = "he_normal"): -def decoder_block(inputs, skip_features, num_filters, kernel: tuple = (3, 3), strides_up: int = 2, - padding: str = "same", activation="relu", kernel_init="he_normal", - l_batch_normalization: bool = True): + x = Conv2D(num_filters * (upscale_fac ** 2), kernel, padding=padding, kernel_initializer=kernel_init, + activation=None)(inputs) + try: + x = Activation(activation)(x) + except ValueError: + ac_layer = advanced_activation(activation, *activation_args) + x = ac_layer(x) + + x = tf.nn.depth_to_space(x, upscale_fac) + + return x + + + +def decoder_block(inputs, skip_features, num_filters, strides_up: int = 2, l_subpixel: bool = False, **kwargs_conv_block): """ One complete decoder block used in U-net (reverting the encoder) """ - x = Conv2DTranspose(num_filters, (strides_up, strides_up), strides=strides_up, padding="same")(inputs) + if l_subpixel: + kwargs_subpixel = kwargs_conv_block.copy() + for ex_key in ["strides", "l_batch_normalization"]: + kwargs_subpixel.pop(ex_key, None) + x = subpixel_block(inputs, num_filters, upscale_fac=strides_up, **kwargs_subpixel) + else: + x = Conv2DTranspose(num_filters, (strides_up, strides_up), strides=strides_up, padding="same")(inputs) + + activation = kwargs_conv_block.get("activation", "relu") + activation_args = kwargs_conv_block.get("activation_args", {}) + + try: + x = Activation(activation)(x) + except ValueError: + ac_layer = advanced_activation(activation, *activation_args) + x = ac_layer(x) + x = Concatenate()([x, skip_features]) - x = conv_block_n(x, num_filters, 2, kernel, (1, 1), padding, activation, kernel_init=kernel_init, - l_batch_normalization=l_batch_normalization) + x = conv_block_n(x, num_filters, 2, **kwargs_conv_block) return x # The particular U-net -def sha_unet(input_shape: tuple, n_predictands_dyn: int, channels_start: int = 56, z_branch: bool = False, - concat_out: bool = False, tar_channels=["output_dyn", "output_z"]) -> Model: +def sha_unet(input_shape: tuple, n_predictands_dyn: int, hparams_unet: dict, concat_out: bool = False, + tar_channels=["output_dyn", "output_z"]) -> Model: """ Builds up U-net model architecture adapted from Sha et al., 2020 (see https://doi.org/10.1175/JAMC-D-20-0057.1). :param input_shape: shape of input-data :param channels_start: number of channels to use as start in encoder blocks :param n_predictands: number of target variables (dynamic output variables) :param z_branch: flag if z-branch is used. 
- :param tar_channels: name of output/target channels (needed for associating losses during compilation) + :param advanced_unet: flag if advanced U-net is used (LeakyReLU instead of ReLU, average pooling instead of max pooling and subpixel-layer) :param concat_out: boolean if output layers will be concatenated (disables named target channels!) + :param tar_channels: name of output/target channels (needed for associating losses during compilation) :return: """ + # basic configuration of U-Net + channels_start = hparams_unet["ngf"] + z_branch = hparams_unet["z_branch"] + kernel_pool = hparams_unet["kernel_pool"] + l_avgpool = hparams_unet["l_avgpool"] + l_subpixel = hparams_unet["l_subpixel"] + + config_conv = {"kernel": hparams_unet["kernel"], "strides": hparams_unet["strides"], "padding": hparams_unet["padding"], + "activation": hparams_unet["activation"], "activation_args": hparams_unet["activation_args"], + "kernel_init": hparams_unet["kernel_init"], "l_batch_normalization": hparams_unet["l_batch_normalization"]} + + # build U-Net inputs = Input(input_shape) """ encoder """ - s1, e1 = encoder_block(inputs, channels_start, l_large=True) - s2, e2 = encoder_block(e1, channels_start * 2, l_large=False) - s3, e3 = encoder_block(e2, channels_start * 4, l_large=False) + s1, e1 = encoder_block(inputs, channels_start, l_large=True, kernel_pool=kernel_pool, l_avgpool=l_avgpool,**config_conv) + s2, e2 = encoder_block(e1, channels_start * 2, l_large=False, kernel_pool=kernel_pool, l_avgpool=l_avgpool,**config_conv) + s3, e3 = encoder_block(e2, channels_start * 4, l_large=False, kernel_pool=kernel_pool, l_avgpool=l_avgpool,**config_conv) """ bridge encoder <-> decoder """ - b1 = conv_block(e3, channels_start * 8) + b1 = conv_block(e3, channels_start * 8, **config_conv) """ decoder """ - d1 = decoder_block(b1, s3, channels_start * 4) - d2 = decoder_block(d1, s2, channels_start * 2) - d3 = decoder_block(d2, s1, channels_start) + d1 = decoder_block(b1, s3, channels_start * 4, l_subpixel=l_subpixel, **config_conv) + d2 = decoder_block(d1, s2, channels_start * 2, l_subpixel=l_subpixel, **config_conv) + d3 = decoder_block(d2, s1, channels_start, l_subpixel=l_subpixel, **config_conv) - output_dyn = Conv2D(n_predictands_dyn, (1, 1), kernel_initializer="he_normal", name=tar_channels[0])(d3) + output_dyn = Conv2D(n_predictands_dyn, (1, 1), kernel_initializer=config_conv["kernel_init"], name=tar_channels[0])(d3) if z_branch: print("Use z_branch...") - output_static = Conv2D(1, (1, 1), kernel_initializer="he_normal", name=tar_channels[1])(d3) + output_static = Conv2D(1, (1, 1), kernel_initializer=config_conv["kernel_init"], name=tar_channels[1])(d3) if concat_out: model = Model(inputs, tf.concat([output_dyn, output_static], axis=-1), name="downscaling_unet_with_z") @@ -173,7 +212,7 @@ class UNET(keras.Model): self.hparams = UNET.get_hparams_dict(hparams) self.n_predictands = len(varnames_tar) # number of predictands self.n_predictands_dyn = self.n_predictands - 1 if self.hparams["z_branch"] else self.n_predictands - if self.hparams["l_embed"]: + if self.hparams.get("l_embed", False): raise ValueError("Embedding is not implemented yet.") self.modelname = exp_name if not os.path.isdir(savedir): @@ -186,7 +225,6 @@ class UNET(keras.Model): See https://stackoverflow.com/questions/65318036/is-it-possible-to-use-the-tensorflow-keras-functional-api-train_unet-model-err.6387845within-a-subclassed-mo for a reference how a model based on Keras functional API has to be integrated into a subclass. 
""" - print(**kwargs) return self.unet(inputs, **kwargs) def get_compile_opts(self): @@ -210,13 +248,14 @@ class UNET(keras.Model): return opt_dict def compile(self, **kwargs): + # instantiate model if self.hparams["z_branch"]: # model has named branches (see also opt_dict in get_compile_opts) tar_channels = [f"{var}" for var in self.varnames_tar] - self.unet = self.unet(self.shape_in, self.n_predictands_dyn, z_branch=True, - concat_out= False, tar_channels=tar_channels) + self.unet = self.unet(self.shape_in, self.n_predictands_dyn, self.hparams, + concat_out=False, tar_channels=tar_channels) else: - self.unet = self.unet(self.shape_in, self.n_predictands_dyn, z_branch=False) + self.unet = self.unet(self.shape_in, self.n_predictands_dyn, self.hparams) return self.unet.compile(**kwargs) # return super(UNET, self).compile(**kwargs) @@ -257,7 +296,9 @@ class UNET(keras.Model): callback_list = callback_list + [LearningRateScheduler(self.get_lr_scheduler(), verbose=1)] if self.hparams["lscheduled_train"]: - callback_list = callback_list + [ModelCheckpoint(self.savedir, monitor="val_loss", verbose=1, + savedir_best = os.path.join(self.savedir, f"{self.modelname}_best") + os.makedirs(savedir_best, exist_ok=True) + callback_list = callback_list + [ModelCheckpoint(savedir_best, monitor="val_t_2m_tar_loss", verbose=1, save_best_only=True, mode="min")] # + EarlyStopping(monitor="val_recon_loss", patience=8)] @@ -279,6 +320,9 @@ class UNET(keras.Model): def save(self, **kwargs): self.unet.save(**kwargs) + def plot_model(self, **kwargs): + self.unet.plot_model(**kwargs) + @staticmethod def get_hparams_dict(hparams_user: dict) -> dict: """ @@ -312,10 +356,13 @@ class UNET(keras.Model): """ Return default hyperparameter dictionary. """ - hparams_dict = {"batch_size": 32, "lr": 5.e-05, "nepochs": 70, "z_branch": True, "loss_func": "mae", - "loss_weights": [1.0, 1.0], "lr_decay": False, "decay_start": 5, "decay_end": 30, - "lr_end": 1.e-06, "l_embed": False, "ngf": 56, "optimizer": "adam", "lscheduled_train": True, - "var_tar2in": "", "n_predictands": 1} + + hparams_dict = {"kernel": (3, 3), "strides": (1, 1), "padding": "same", "activation": "swish", "activation_args": {}, # arguments for building blocks of U-Net: + "kernel_init": "he_normal", "l_batch_normalization": True, "kernel_pool": (2, 2), "l_avgpool": True, # see keyword-aguments of sha_unet, conv_block, + "l_subpixel": True, "z_branch": True, "ngf": 56, # encoder_block and decoder_block + "batch_size": 32, "lr": 5.e-05, "nepochs": 70, "loss_func": "mae", "loss_weights": [1.0, 1.0], # training parameters + "lr_decay": False, "decay_start": 5, "decay_end": 30, "lr_end": 1.e-06, "l_embed": False, + "optimizer": "adam", "lscheduled_train": True} return hparams_dict diff --git a/downscaling_ap5/models/wgan_model.py b/downscaling_ap5/models/wgan_model.py index 1e14df96c255b6fb96798099b3f2ba0176ad3eec..53ea658af545b465b26391fb4cf2db440c39a3b3 100644 --- a/downscaling_ap5/models/wgan_model.py +++ b/downscaling_ap5/models/wgan_model.py @@ -5,7 +5,7 @@ __author__ = "Michael Langguth" __email__ = "m.langguth@fz-juelich.de" __date__ = "2022-05-19" -__update__ = "2022-11-25" +__update__ = "2023-08-08" import os, sys from typing import List, Tuple, Union @@ -23,14 +23,15 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.keras.layers import (Input, Dense, GlobalAveragePooling2D) from tensorflow.keras.models import Model # other modules +from custom_losses import get_custom_loss from unet_model import conv_block -from other_utils 
import to_list +from other_utils import to_list, merge_dicts list_or_tuple = Union[List, Tuple] def critic_model(shape, num_conv: int = 4, channels_start: int = 64, kernel: tuple = (3, 3), - stride: tuple = (2, 2), activation: str = "relu", lbatch_norm: bool = True): + stride: tuple = (2, 2), activation: str = "swish", lbatch_norm: bool = True): """ Set-up convolutional discriminator model that is followed by two dense-layers :param shape: input shape of data (either real or generated data) @@ -56,7 +57,7 @@ def critic_model(shape, num_conv: int = 4, channels_start: int = 64, kernel: tup # finally perform global average pooling and finalize by fully connected layers x = GlobalAveragePooling2D()(x) - x = Dense(channels_start)(x) + x = Dense(channels_start, activation=activation)(x) # ... and end with linear output layer out = Dense(1, activation="linear")(x) @@ -91,7 +92,7 @@ class WGAN(keras.Model): if self.hparams["l_embed"]: raise ValueError("Embedding is not implemented yet.") self.n_predictands = len(varnames_tar) # number of predictands - self.n_predictands_dyn = self.n_predictands - 1 if self.hparams["z_branch"] else self.n_predictands + self.n_predictands_dyn = self.n_predictands - 1 if self.hparams["hparams_generator"]["z_branch"] else self.n_predictands print(f"Dynamic predictands: {self.n_predictands_dyn}, Predictands: {self.n_predictands}") self.modelname = exp_name if not os.path.isdir(savedir): @@ -100,10 +101,16 @@ class WGAN(keras.Model): # instantiate submodels # instantiate model components (generator and critci) - self.generator = self.generator(self.shape_in, self.n_predictands, channels_start=self.hparams["ngf"], - concat_out=True, z_branch=self.hparams["z_branch"]) + self.generator = self.generator(self.shape_in, self.n_predictands_dyn, self.hparams["hparams_generator"], concat_out=True) tar_shape = (*self.shape_in[:-1], self.n_predictands_dyn) # critic only accounts for dynamic predictands - self.critic = self.critic(tar_shape) + critic_kwargs = self.hparams["hparams_critic"].copy() + critic_kwargs.pop("lr", None) + self.critic = self.critic(tar_shape, **critic_kwargs) + + # losses + self.critic_loss = get_custom_loss("critic") + self.critic_gen_loss = get_custom_loss("critic_generator") + self.recon_loss = self.get_recon_loss() # Unused attribute, but introduced for joint driver script with U-Net; to be solved with customized target vars self.varnames_tar = None @@ -113,6 +120,19 @@ class WGAN(keras.Model): self.lr_scheduler = None self.checkpoint, self.earlystopping = None, None + def get_recon_loss(self): + + kwargs_loss = {} + if "vec" in self.hparams["recon_loss"]: + kwargs_loss = {"nd_vec": self.hparams.get("nd_vec", 2), "n_channels": self.n_predictands} + elif "channels" in self.hparams["recon_loss"]: + kwargs_loss = {"n_channels": self.n_predictands} + + loss_fn = get_custom_loss(self.hparams["recon_loss"], **kwargs_loss) + + return loss_fn + + def compile(self, **kwargs): """ Instantiate generator and critic, compile model and then set optimizer as well optional learning rate decay and @@ -143,7 +163,7 @@ class WGAN(keras.Model): :return: learning rate scheduler """ decay_st, decay_end = self.hparams["decay_start"], self.hparams["decay_end"] - lr_start, lr_end = self.hparams["lr_gen"], self.hparams["lr_gen_end"] + lr_start, lr_end = self.hparams["hparams_generator"]["lr"], self.hparams["hparams_generator"]["lr_end"] if not decay_end > decay_st: raise ValueError("Epoch for end of learning rate decay must be large than start epoch. 
" + @@ -206,7 +226,7 @@ class WGAN(keras.Model): critic_gen = self.critic(gen_data[..., 0:self.n_predictands_dyn], training=True) critic_gt = self.critic(predictands_critic, training=True) # calculate the loss (incl. gradient penalty) - c_loss = WGAN.critic_loss(critic_gt, critic_gen) + c_loss = self.critic_loss(critic_gt, critic_gen) gp = self.gradient_penalty(predictands_critic, gen_data[..., 0:self.n_predictands_dyn]) d_loss = c_loss + self.hparams["gp_weight"] * gp @@ -221,7 +241,7 @@ class WGAN(keras.Model): gen_data = self.generator(predictors[-self.hparams["batch_size"]:, :, :, :], training=True) # get the critic and calculate corresponding generator losses (critic and reconstruction loss) critic_gen = self.critic(gen_data[..., 0:self.n_predictands_dyn], training=True) - cg_loss = WGAN.critic_gen_loss(critic_gen) + cg_loss = self.critic_gen_loss(critic_gen) rloss = self.recon_loss(predictands[-self.hparams["batch_size"]:, :, :, :], gen_data) g_loss = cg_loss + self.hparams["recon_weight"] * rloss @@ -275,14 +295,6 @@ class WGAN(keras.Model): return gp - def recon_loss(self, real_data, gen_data): - # initialize reconstruction loss - rloss = 0. - # get MAE for all output heads - for i in range(self.hparams["n_predictands"]): - rloss += tf.reduce_mean(tf.abs(tf.squeeze(gen_data[..., i]) - real_data[..., i])) - - return rloss def plot_model(self, save_dir, **kwargs): """ @@ -340,28 +352,22 @@ class WGAN(keras.Model): # check if parsed hyperparameters are known unknown_keys = [key for key in hparams_user.keys() if key not in hparams_default] - if unknown_keys: + if unknown_keys: # if unknown keys are found, remove them from user-defined hyperparameters print("The following parsed hyperparameters are unknown and thus are ignored: {0}".format( ", ".join(unknown_keys))) + [hparams_user.pop(unknown_key) for unknown_key in unknown_keys] + + hparams_dict = merge_dicts(hparams_default, hparams_user) # merge default and user-defined hyperparameters - # get complete hyperparameter dictionary while checking type of parsed values - hparams_merged = {**hparams_default, **hparams_user} - hparams_dict = {} - for key in hparams_default: - if isinstance(hparams_merged[key], type(hparams_default[key])): - hparams_dict[key] = hparams_merged[key] - else: - raise TypeError("Parsed hyperparameter '{0}' must be of type '{1}', but is '{2}'" - .format(key, type(hparams_default[key]), type(hparams_merged[key]))) - + # check if optimizer is valid and set corresponding optimizers for generator and critic if hparams_dict["optimizer"].lower() == "adam": adam = keras.optimizers.Adam - hparams_dict["d_optimizer"] = adam(learning_rate=hparams_dict["lr_critic"], beta_1=0.0, beta_2=0.9) - hparams_dict["g_optimizer"] = adam(learning_rate=hparams_dict["lr_gen"], beta_1=0.0, beta_2=0.9) + hparams_dict["d_optimizer"] = adam(learning_rate=hparams_dict["hparams_critic"]["lr"], beta_1=0.0, beta_2=0.9) + hparams_dict["g_optimizer"] = adam(learning_rate=hparams_dict["hparams_generator"]["lr"], beta_1=0.0, beta_2=0.9) elif hparams_dict["optimizer"].lower() == "rmsprop": rmsprop = keras.optimizers.RMSprop - hparams_dict["d_optimizer"] = rmsprop(lr=hparams_dict["lr_critic"]) # increase beta-values ? - hparams_dict["g_optimizer"] = rmsprop(lr=hparams_dict["lr_gen"]) + hparams_dict["d_optimizer"] = rmsprop(lr=hparams_dict["hparams_critic"]["lr"]) # increase beta-values ? + hparams_dict["g_optimizer"] = rmsprop(lr=hparams_dict["hparams_generator"]["lr"]) else: raise ValueError("'{0}' is not a valid optimizer. 
Either choose Adam or RMSprop-optimizer") @@ -372,32 +378,18 @@ class WGAN(keras.Model): """ Return default hyperparameter dictionary. """ - hparams_dict = {"batch_size": 32, "lr_gen": 1.e-05, "lr_critic": 1.e-06, "nepochs": 50, "z_branch": False, - "lr_decay": False, "decay_start": 5, "decay_end": 10, "lr_gen_end": 1.e-06, "l_embed": False, - "ngf": 56, "d_steps": 5, "recon_weight": 1000., "gp_weight": 10., "optimizer": "adam", - "lscheduled_train": True, "var_tar2in": "", "n_predictands": 2} - + hparams_dict = {"batch_size": 32, "nepochs": 50, "lr_decay": False, "decay_start": 5, "decay_end": 10, + "l_embed": False, "d_steps": 5, "recon_weight": 1000., "gp_weight": 10., "optimizer": "adam", + "lscheduled_train": True, "recon_loss": "mae_channels", + "hparams_generator": {"kernel": (3, 3), "strides": (1, 1), "padding": "same", "activation": "swish", "activation_args": {}, # arguments for building blocks of U-Net: + "kernel_init": "he_normal", "l_batch_normalization": True, "kernel_pool": (2, 2), "l_avgpool": True, # see keyword-aguments of sha_unet, conv_block, + "l_subpixel": True, "z_branch": True, "ngf": 56, "lr": 1.e-05, "lr_end": 1.e-06}, + "hparams_critic": {"num_conv": 4, "channels_start": 64, "activation": "swish", + "lbatch_norm": True, "kernel": (3, 3), "stride": (2, 2), "lr": 1.e-06,} + } + return hparams_dict - @staticmethod - def critic_loss(critic_real, critic_gen): - """ - The critic is optimized to maximize the difference between the generated and the real data max(real - gen). - This is equivalent to minimizing the negative of this difference, i.e. min(gen - real) = max(real - gen) - :param critic_real: critic on the real data - :param critic_gen: critic on the generated data - :return c_loss: loss to optize the critic - """ - c_loss = tf.reduce_mean(critic_gen - critic_real) - - return c_loss - - @staticmethod - def critic_gen_loss(critic_gen): - cg_loss = -tf.reduce_mean(critic_gen) - - return cg_loss - class LearningRateSchedulerWGAN(LearningRateScheduler): diff --git a/downscaling_ap5/postprocess/jupyter_notebooks/meta_postprocessing_gradient_ratio.ipynb b/downscaling_ap5/postprocess/jupyter_notebooks/meta_postprocessing_gradient_ratio.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..1ea138b3d136b972b2330dd3a1949491e3baa5c8 --- /dev/null +++ b/downscaling_ap5/postprocess/jupyter_notebooks/meta_postprocessing_gradient_ratio.ipynb @@ -0,0 +1,183 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "6d799556-9628-4ed8-acfa-5a8f31190e5c", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "\n", + "import os\n", + "import numpy as np\n", + "import xarray as xr\n", + "import pandas as pd\n", + "\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8a64d03c-882d-4a07-9fcc-2a15adbb10fa", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# auxiliary methods\n", + "\n", + "def create_lines_plot(data: xr.DataArray, data_std: xr.DataArray, model_names: str, metric: dict,\n", + " plt_fname: str, x_coord: str = \"hour\", **kwargs):\n", + "\n", + " # get some plot parameters\n", + " linestyle = kwargs.get(\"linestyle\", [\"k-\", \"b-\"])\n", + " err_col = kwargs.get(\"error_color\", [\"grey\", \"blue\"])\n", + " val_range = kwargs.get(\"value_range\", (0.7, 1.1))\n", + " fs = kwargs.get(\"fs\", 16)\n", + " ref_line = kwargs.get(\"ref_line\", None)\n", + " ref_linestyle = kwargs.get(\"ref_linestyle\", \"k--\")\n", 
+ " \n", + " fig, (ax) = plt.subplots(1, 1, figsize=(8, 6))\n", + " for i, exp in enumerate(data[\"exp\"]):\n", + " ax.plot(data[x_coord].values, data.sel({\"exp\": exp}).values, linestyle[i],\n", + " label=model_names[i])\n", + " ax.fill_between(data[x_coord].values, data.sel({\"exp\": exp}).values-data_std.sel({\"exp\": exp}).values,\n", + " data.sel({\"exp\": exp}).values+data_std.sel({\"exp\": exp}).values, facecolor=err_col[i],\n", + " alpha=0.2)\n", + " if ref_line is not None:\n", + " nval = np.shape(data[x_coord].values)[0]\n", + " ax.plot(data[x_coord].values, np.full(nval, ref_line), ref_linestyle)\n", + " ax.set_ylim(*val_range)\n", + " ax.set_yticks(np.arange(*val_range, 0.05))\n", + " # label axis\n", + " ax.set_xlabel(\"daytime [UTC]\", fontsize=fs)\n", + " metric_name, metric_unit = list(metric.keys())[0], list(metric.values())[0]\n", + " ax.set_ylabel(f\"{metric_name} T2m [{metric_unit}]\", fontsize=fs)\n", + " ax.tick_params(axis=\"both\", which=\"both\", direction=\"out\", labelsize=fs-2)\n", + " ax.legend(fontsize=fs-2, loc=\"upper right\")\n", + "\n", + " # save plot and close figure\n", + " plt_fname = plt_fname + \".png\" if not plt_fname.endswith(\".png\") else plt_fname\n", + " print(f\"Save plot in file '{plt_fname}'\")\n", + " #plt.tight_layout()\n", + " #fig.savefig(plt_fname)\n", + " fig.savefig(plt_fname, bbox_inches=\"tight\")\n", + " plt.close(fig)\n", + "\n", + "def get_id_from_fname(fname):\n", + " try:\n", + " start_index = fname.find(\"id\") + 2 # Adding 2 to move past \"id\"\n", + " end_index = fname.find(\"_\", start_index)\n", + " \n", + " exp_id = fname[start_index:end_index]\n", + " except:\n", + " raise ValueError(f\"Failed to deduce experiment ID from '{fname}'\")\n", + " \n", + " return exp_id" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d8470099-80b7-46b5-bb68-933d5549ec0b", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# parameters\n", + "results_basedir = \"/p/home/jusers/langguth1/juwels/downscaling_maelstrom/downscaling_jsc_repo/downscaling_ap5/results\"\n", + "plt_dir = os.path.join(results_basedir, \"meta\")\n", + "\n", + "exp1 = \"wgan_t2m_atmorep_test\"\n", + "exp2 = \"atmorep_id26n32cey\"\n", + "\n", + "varname = \"T2m\"\n", + "year = 2018" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c05778d6-0835-40b9-8010-149c05e00519", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# main\n", + "os.makedirs(plt_dir, exist_ok=True)\n", + "\n", + "fexp1 = os.path.join(results_basedir, exp1, \"metric_files\", \"eval_grad_amplitude_year.csv\")\n", + "fexp2 = os.path.join(results_basedir, exp2, \"metric_files\", \"eval_grad_amplitude__small_dom_year.csv\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3634bcd3-40f7-4497-9561-3f41c00fc039", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "dims = [\"hour\", \"type\"]\n", + "coord_dict = {\"hour\": np.arange(24), \"type\": [\"mean\", \"std\"]}\n", + "\n", + "da_gr_exp1 = xr.DataArray(pd.read_csv(fexp1, header=0, index_col=0), dims=dims, coords=coord_dict)\n", + "da_gr_exp2 = xr.DataArray(pd.read_csv(fexp2, header=0, index_col=0), dims=dims, coords=coord_dict)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "492e8f95-e00f-43f6-951e-72e9a2a60d35", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "da_gr_all = xr.concat([da_gr_exp1, da_gr_exp2], dim= \"exp\")\n", + "da_gr_all = da_gr_all.assign_coords({\"exp\": [exp1, exp2]})\n", 
+ "\n", + "# create plot\n", + "plt_fname = os.path.join(plt_dir, f\"eval_grad_amplitude_{exp1}_{exp2}.png\")\n", + "create_lines_plot(da_gr_all.sel({\"type\": \"mean\"}), da_gr_all.sel({\"type\": \"std\"}),\n", + " [\"WGAN\", \"AtmoRep\"], {\"GRAD_AMPLITUDE\": \"1\"}, plt_fname, re)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f0bb44d2-85eb-4bd8-9922-ba3ce5ffd38d", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "langguth1_downscaling_kernel_juwels", + "language": "python", + "name": "langguth1_downscaling_kernel_juwels" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/downscaling_ap5/postprocess/jupyter_notebooks/meta_postprocessing_rmse.ipynb b/downscaling_ap5/postprocess/jupyter_notebooks/meta_postprocessing_rmse.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..8b68161292e052e37b22bf578cd7d48a6e593676 --- /dev/null +++ b/downscaling_ap5/postprocess/jupyter_notebooks/meta_postprocessing_rmse.ipynb @@ -0,0 +1,190 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "b36a047f-7db7-4997-9959-3b4767f43345", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "\n", + "import os\n", + "import numpy as np\n", + "import xarray as xr\n", + "import pandas as pd\n", + "\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ae48851c-7d59-4128-af9d-97795e610aea", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# auxiliary methods\n", + "\n", + "def create_lines_plot(data: xr.DataArray, data_std: xr.DataArray, model_names: str, metric: dict,\n", + " plt_fname: str, x_coord: str = \"hour\", **kwargs):\n", + "\n", + " # get some plot parameters\n", + " linestyle = kwargs.get(\"linestyle\", [\"k-\", \"b-\"])\n", + " err_col = kwargs.get(\"error_color\", [\"grey\", \"blue\"])\n", + " val_range = kwargs.get(\"value_range\", (0., 3.))\n", + " fs = kwargs.get(\"fs\", 16)\n", + " ref_line = kwargs.get(\"ref_line\", None)\n", + " ref_linestyle = kwargs.get(\"ref_linestyle\", \"k--\")\n", + " \n", + " fig, (ax) = plt.subplots(1, 1)\n", + " for i, exp in enumerate(data[\"exp\"]):\n", + " ax.plot(data[x_coord].values, data.sel({\"exp\": exp}).values, linestyle[i],\n", + " label=model_names[i])\n", + " ax.fill_between(data[x_coord].values, data.sel({\"exp\": exp}).values-data_std.sel({\"exp\": exp}).values,\n", + " data.sel({\"exp\": exp}).values+data_std.sel({\"exp\": exp}).values, facecolor=err_col[i],\n", + " alpha=0.2)\n", + " if ref_line is not None:\n", + " nval = np.shape(data[x_coord].values)[0]\n", + " ax.plot(data[x_coord].values, np.full(nval, ref_line), ref_linestyle)\n", + " ax.set_ylim(*val_range)\n", + " # label axis\n", + " ax.set_xlabel(\"daytime [UTC]\", fontsize=fs)\n", + " metric_name, metric_unit = list(metric.keys())[0], list(metric.values())[0]\n", + " ax.set_ylabel(f\"{metric_name} T2m [{metric_unit}]\", fontsize=fs)\n", + " ax.tick_params(axis=\"both\", which=\"both\", direction=\"out\", labelsize=fs-2)\n", + " ax.legend(fontsize=fs-2, loc=\"upper right\")\n", + "\n", + " # save plot and close figure\n", + " plt_fname = plt_fname + \".png\" if not 
plt_fname.endswith(\".png\") else plt_fname\n", + " print(f\"Save plot in file '{plt_fname}'\")\n", + " plt.tight_layout()\n", + " fig.savefig(plt_fname)\n", + " plt.close(fig)\n", + "\n", + "def get_id_from_fname(fname):\n", + " try:\n", + " start_index = fname.find(\"id\") + 2 # Adding 2 to move past \"id\"\n", + " end_index = fname.find(\"_\", start_index)\n", + " \n", + " exp_id = fname[start_index:end_index]\n", + " except:\n", + " raise ValueError(f\"Failed to deduce experiment ID from '{fname}'\")\n", + " \n", + " return exp_id" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6d57de46-706f-4386-8e14-1263a8c96666", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# parameters\n", + "results_basedir = \"/p/home/jusers/langguth1/juwels/downscaling_maelstrom/downscaling_jsc_repo/downscaling_ap5/results\"\n", + "plt_dir = os.path.join(results_basedir, \"meta\")\n", + "\n", + "exp1 = \"wgan_t2m_atmorep_test\"\n", + "exp2 = \"atmorep_id26n32cey\"\n", + "\n", + "varname = \"T2m\"\n", + "year = 2018" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9c2327f6-85e9-4bc4-acb6-21b110abcdc9", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# main\n", + "os.makedirs(plt_dir, exist_ok=True)\n", + "\n", + "fexp1 = os.path.join(results_basedir, exp1, \"metric_files\", \"eval_rmse_year.csv\")\n", + "fexp2 = os.path.join(results_basedir, exp2, \"metric_files\", \"eval_rmse__small_dom_year.csv\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0d5c40cd-586d-46b8-8069-2a14a6ab7eee", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "dims = [\"hour\", \"type\"]\n", + "coord_dict = {\"hour\": np.arange(24), \"type\": [\"mean\", \"std\"]}\n", + "\n", + "da_rmse_exp1 = xr.DataArray(pd.read_csv(fexp1, header=0, index_col=0), dims=dims, coords=coord_dict)\n", + "da_rmse_exp2 = xr.DataArray(pd.read_csv(fexp2, header=0, index_col=0), dims=dims, coords=coord_dict)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ca2993fb-f008-4ef4-bcff-b3a1745116ff", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "da_rmse_all = xr.concat([da_rmse_exp1, da_rmse_exp2], dim= \"exp\")\n", + "da_rmse_all = da_rmse_all.assign_coords({\"exp\": [exp1, exp2]})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6fd5217b-c59d-4979-81e7-790ff810af04", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "plt_fname = os.path.join(plt_dir, f\"eval_rmse_{exp1}_{exp2}.png\")\n", + "create_lines_plot(da_rmse_all.sel({\"type\": \"mean\"}), da_rmse_all.sel({\"type\": \"std\"}),\n", + " [\"WGAN\", \"AtmoRep\"], {\"RMSE\": \"K\"}, plt_fname)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "70d7870d-25f1-49de-a646-b1d062edc474", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "PyEarthSystem-2023.5", + "language": "python", + "name": "pyearthsystem" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/downscaling_ap5/postprocess/postprocess_wgan_era5_ifs.ipynb b/downscaling_ap5/postprocess/jupyter_notebooks/postprocess_wgan_era5_ifs.ipynb similarity index 99% rename from 
downscaling_ap5/postprocess/postprocess_wgan_era5_ifs.ipynb rename to downscaling_ap5/postprocess/jupyter_notebooks/postprocess_wgan_era5_ifs.ipynb index 6d42a70f7986cddb491eb2b43683e56f27c9e048..3b8a6ff2ce63be53959752f8f8b083689c7bce74 100644 --- a/downscaling_ap5/postprocess/postprocess_wgan_era5_ifs.ipynb +++ b/downscaling_ap5/postprocess/jupyter_notebooks/postprocess_wgan_era5_ifs.ipynb @@ -415,9 +415,9 @@ ], "metadata": { "kernelspec": { - "display_name": "langguth1_downscaling_kernel_juwels", + "display_name": "langguth1_downscaling_kernel", "language": "python", - "name": "langguth1_downscaling_kernel_juwels" + "name": "langguth1_downscaling_kernel" }, "language_info": { "codemirror_mode": { diff --git a/downscaling_ap5/postprocess/plotting.py b/downscaling_ap5/postprocess/plotting.py index cb706c2dbdaf33bcf18a89186403eac3cef0a80f..ce09863fa5da0ca0b350d6066e6e4430b22daa4e 100644 --- a/downscaling_ap5/postprocess/plotting.py +++ b/downscaling_ap5/postprocess/plotting.py @@ -9,7 +9,7 @@ Methods for creating plots. __author__ = "Michael Langguth" __email__ = "m.langguth@fz-juelich.de" __date__ = "2022-01-20" -__update__ = "2022-12-08" +__update__ = "2023-12-10" # for processing data import os @@ -26,6 +26,9 @@ import cartopy.crs as ccrs from other_utils import provide_default # auxiliary variable for logger +# auxiliary variable for logger +logger_module_name = f"main_postprocess.{__name__}" +module_logger = logging.getLogger(logger_module_name) module_name = os.path.basename(__file__).rstrip(".py") @@ -50,7 +53,7 @@ def get_colormap_temp(levels=None): # for making plot nice -def decorate_plot(ax_plot, plot_xlabel=True, plot_ylabel=True): +def decorate_plot(ax_plot, plot_xlabel=True, plot_ylabel=True, extent=[3.5, 16.5, 44.5, 54.]): fs = 16 # if "login" in host: # add nice coast- and borderlines @@ -62,7 +65,7 @@ def decorate_plot(ax_plot, plot_xlabel=True, plot_ylabel=True): ax_plot.set_xticks(np.arange(0., 360. + 0.1, 5.)) # ,crs=projection_crs) ax_plot.set_yticks(np.arange(-90., 90. 
+ 0.1, 5.)) # ,crs=projection_crs) - ax_plot.set_extent([3.5, 16.5, 44.5, 54.]) # , crs=prj_crs) + ax_plot.set_extent(extent) # , crs=prj_crs) ax_plot.minorticks_on() ax_plot.tick_params(axis="both", which="both", direction="out", labelsize=fs) @@ -127,7 +130,7 @@ def create_map_score(score, plt_fname, **kwargs): func_logger = logging.getLogger(f"postprocess.{module_name}.{create_map_score.__name__}") # get keywor arguments - score_dims = kwargs.get("score_dims", ["lat", "lon"]) + dims = kwargs.get("dims", ["lat", "lon"]) title = kwargs.get("title", "Score") levels = kwargs.get("levels", np.arange(-5., 5., 0.5)) # auxiliary variables @@ -135,17 +138,22 @@ def create_map_score(score, plt_fname, **kwargs): nbounds = len(lvl) cmap = kwargs.get("cmap", mpl.cm.PuOr_r(np.linspace(0., 1., nbounds))) fs = kwargs.get("fs", 16) - projection = kwargs.get("projection", ccrs.PlateCarree()) + projection = kwargs.get("projection", ccrs.RotatedPole(pole_longitude=-162.0, pole_latitude=39.25)) + extent = kwargs.get("extent", None) + + decorate_dict = {} + if extent: + decorate_dict["extent"] = extent # get coordinate data try: - lat, lon = score[score_dims[0]].values, score[score_dims[1]].values + lat, lon = score[dims[0]].values, score[dims[1]].values except Exception as err: print("Failed to retrieve coordinates from score-data") raise err # construct array for edges of grid points - dy, dx = np.round((lat[1] - lat[0]), 3), np.round((lon[1] - lon[0]), 3) - lat_e, lon_e = np.arange(lat[0]-dy/2, lat[-1]+dy, dy), np.arange(lon[0]-dx/2, lon[-1]+dx, dx) + dy, dx = np.round((lat[1] - lat[0]), 4), np.round((lon[1] - lon[0]), 4) + lat_e, lon_e = np.arange(lat[0]-dy/2, lat[-1]+dy, dy), np.arange(lon[0]-dx/2, lon[-1]+dx, dx) # get colormap # create colormap and corresponding norm @@ -158,7 +166,7 @@ def create_map_score(score, plt_fname, **kwargs): plt1 = ax.pcolormesh(lon_e, lat_e, np.squeeze(score.values), cmap=cmap_obj, norm=norm, transform=projection) - ax = decorate_plot(ax) + ax = decorate_plot(ax, **decorate_dict) ax.set_title(title, size=fs) @@ -208,3 +216,66 @@ def create_line_plot(data: xr.DataArray, data_std: xr.DataArray, model_name: str plt.tight_layout() fig.savefig(plt_fname) plt.close(fig) + + +# write the create_box_plot function +def create_box_plot(data, plt_fname: str, **plt_kwargs): + """ + Create box plot of feature importance scores + :param feature_scores: Feature importance scores with predictors as firstdimension and time as second dimension + :param plt_fname: File name of plot + """ + func_logger = logging.getLogger(f"postprocess.{module_name}.{create_box_plot.__name__}") + + # get some plot parameters + val_range = plt_kwargs.get("value_range", [None]) + widths = plt_kwargs.get("widths", None) + colors = plt_kwargs.get("colors", None) + fs = plt_kwargs.get("fs", 16) + ref_line = plt_kwargs.get("ref_line", 1.) 
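A minimal usage sketch for the updated `create_map_score` interface above (keyword `dims` instead of `score_dims`, rotated-pole projection by default, optional `extent`); the score field is synthetic and the level and extent values are placeholders:

```python
# Illustrative call of create_map_score on a synthetic score field defined on
# rotated (rlat, rlon) coordinates; all values are made up.
import numpy as np
import xarray as xr
from plotting import create_map_score

rlat = np.arange(-5., 5., 0.0625)
rlon = np.arange(-8., 8., 0.0625)
score_mean = xr.DataArray(np.random.rand(rlat.size, rlon.size),
                          coords={"rlat": rlat, "rlon": rlon}, dims=["rlat", "rlon"])

create_map_score(score_mean, "rmse_avg_map.png", dims=["rlat", "rlon"],
                 title="RMSE (avg.)", levels=np.arange(0., 2.01, 0.1),
                 extent=[-8., 8., -5., 5.])   # extent chosen to match the synthetic grid
```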
+ ref_linestyle = plt_kwargs.get("ref_linestyle", "k-") + title = plt_kwargs.get("title", "") + ylabel = plt_kwargs.get("ylabel", "") + xlabel = plt_kwargs.get("xlabel", "") + yticks = plt_kwargs.get("yticks", None) + labels = plt_kwargs.get("labels", None) + + # create box whiskers plot with matplotlib + fig, ax = plt.subplots(figsize=(12, 8)) + + bp = plt.boxplot(data, widths=widths, labels=labels, patch_artist=True) + + # modify fliers + fliers = bp['fliers'] + for i in range(len(fliers)): # iterate through the Line2D objects for the fliers for each boxplot + box = fliers[i] # this accesses the x and y vectors for the fliers for each box + box.set_data([[box.get_xdata()[0]],[np.max(box.get_ydata())]]) + + if ref_line is not None: + nval = len(fliers) + ax.plot(np.array(range(0, nval+1)) + 0.5, np.full(nval+1, ref_line), ref_linestyle) + + if colors is None: + pass + else: + if isinstance(colors, str): colors = len(bp["boxes"])*[colors] + for patch, color in zip(bp['boxes'], colors): + patch.set_facecolor(color) + + ax.set_ylim(*val_range) + ax.set_yticks(yticks) + + ax.set_title(title, fontsize=fs + 2) + ax.set_ylabel(ylabel, fontsize=fs, labelpad=8) + ax.set_xlabel(xlabel, fontsize=fs, labelpad=8) + ax.tick_params(axis="both", which="both", direction="out", labelsize=fs-2) + ax.yaxis.grid(True) + + # save plot + plt.tight_layout() + plt.savefig(plt_fname + ".png" if not plt_fname.endswith(".png") else plt_fname) + plt.close(fig) + + func_logger.info(f"Feature importance scores saved to {plt_fname}.") + + return True diff --git a/downscaling_ap5/postprocess/postprocess.py b/downscaling_ap5/postprocess/postprocess.py index bac9c6ab6ce2ca33c88344537ed499a5d8fb883c..3b6b74d751404c5e0753b055f7bff845845206d8 100644 --- a/downscaling_ap5/postprocess/postprocess.py +++ b/downscaling_ap5/postprocess/postprocess.py @@ -9,13 +9,20 @@ Auxiliary methods for postprocessing. 
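Since `create_box_plot` above is written generically, a self-contained sketch of a call with synthetic relative-change scores may help; the keyword arguments mirror those used by `run_feature_importance`, while the predictor names and all numbers are invented:

```python
# Synthetic feature-importance box plot: 240 samples for 4 hypothetical predictors.
import numpy as np
from plotting import create_box_plot

rng = np.random.default_rng(42)
rel_changes = 1. + np.abs(rng.normal(0.2, 0.05, size=(240, 4)))
rel_changes[:3] = 2.5   # a few large values so each box has fliers to be thinned out

create_box_plot(rel_changes, "feature_importance_rmse.png",
                title="Feature Importance (RMSE)", ref_line=1., widths=0.3,
                xlabel="Predictors", ylabel="Rel. change RMSE",
                labels=["t2m_in", "z_in", "u700_in", "v700_in"],
                yticks=range(1, 4), colors="b")
```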
__author__ = "Michael Langguth" __email__ = "m.langguth@fz-juelich.de" __date__ = "2022-12-08" -__update__ = "2022-12-08" +__update__ = "2023-10-12" import os +from typing import Union, List import logging +import numpy as np import xarray as xr from cartopy import crs -from plotting import create_line_plot, create_map_score +from statistical_evaluation import feature_importance +from plotting import create_line_plot, create_map_score, create_box_plot + +# basic data types +da_or_ds = Union[xr.DataArray, xr.Dataset] +list_or_str = Union[List[str], str] # auxiliary variable for logger logger_module_name = f"main_postprocess.{__name__}" @@ -28,18 +35,19 @@ def get_model_info(model_base, output_base: str, exp_name: str, bool_last: bool func_logger = logging.getLogger(f"{logger_module_name}.{get_model_info.__name__}") model_name = os.path.basename(model_base) + norm_dir = model_base + + + add_str = "_last" if bool_last else "_best" if "wgan" in exp_name: func_logger.debug(f"WGAN-modeltype detected.") - add_str = "_last" if bool_last else "" model_dir, plt_dir = os.path.join(model_base, f"{exp_name}_generator{add_str}"), \ os.path.join(output_base, model_name) - norm_dir = model_base model_type = "wgan" elif "unet" in exp_name or "deepru" in exp_name: func_logger.debug(f"U-Net-modeltype detected.") - model_dir, plt_dir = model_base, os.path.join(output_base, model_name) - norm_dir = model_dir + model_dir, plt_dir = os.path.join(model_base, f"{exp_name}{add_str}"), os.path.join(output_base, model_name) model_type = "unet" if "unet" in exp_name else "deepru" else: func_logger.debug(f"Model type could not be inferred from experiment name. Try my best by defaulting...") @@ -53,6 +61,29 @@ def get_model_info(model_base, output_base: str, exp_name: str, bool_last: bool return model_dir, plt_dir, norm_dir, model_type +def run_feature_importance(da: xr.DataArray, predictors: list_or_str, varname_tar: str, model, norm, score_name: str, + ref_score: float, data_loader_opt: dict, plt_dir: str, patch_size = (6, 6), variable_dim = "variable"): + + # get local logger + func_logger = logging.getLogger(f"{logger_module_name}.{run_feature_importance.__name__}") + + # get feature importance scores + feature_scores = feature_importance(da, predictors, varname_tar, model, norm, score_name, data_loader_opt, + patch_size=patch_size, variable_dim=variable_dim) + + rel_changes = feature_scores / ref_score + max_rel_change = int(np.ceil(np.amax(rel_changes) + 1.)) + + # plot feature importance scores in a box-plot with whiskers where each variable is a box + plt_fname = os.path.join(plt_dir, f"feature_importance_{score_name}.png") + + create_box_plot(rel_changes.T, plt_fname, **{"title": f"Feature Importance ({score_name.upper()})", "ref_line": 1., "widths": .3, + "xlabel": "Predictors", "ylabel": f"Rel. change {score_name.upper()}", "labels": predictors, + "yticks": range(1, max_rel_change), "colors": "b"}) + + return feature_scores + + def run_evaluation_time(score_engine, score_name: str, score_unit: str, plot_dir: str, **plt_kwargs): """ Create line plots of desired evaluation metric. 
Evaluation metric must have a time-dimension @@ -64,79 +95,102 @@ def run_evaluation_time(score_engine, score_name: str, score_unit: str, plot_dir # get local logger func_logger = logging.getLogger(f"{logger_module_name}.{run_evaluation_time.__name__}") + # create output-directories if necessary + metric_dir = os.path.join(plot_dir, "metric_files") os.makedirs(plot_dir, exist_ok=True) + os.makedirs(metric_dir, exist_ok=True) + model_type = plt_kwargs.get("model_type", "wgan") func_logger.info(f"Start evaluation in terms of {score_name}") score_all = score_engine(score_name) + score_all = score_all.drop_vars("variables") func_logger.info(f"Globally averaged {score_name}: {score_all.mean().values:.4f} {score_unit}, " + - f"standard deviation: {score_all.std().values:.4f}") - + f"standard deviation: {score_all.std().values:.4f}") + score_hourly_all = score_all.groupby("time.hour") score_hourly_mean, score_hourly_std = score_hourly_all.mean(), score_hourly_all.std() - for hh in range(24): - func_logger.debug(f"Evaluation for {hh:02d} UTC") - if hh == 0: - tmp = score_all.isel({"time": score_all.time.dt.hour == hh}).groupby("time.season") - score_hourly_mean_sea, score_hourly_std_sea = tmp.mean().copy(), tmp.std().copy() - else: - tmp = score_all.isel({"time": score_all.time.dt.hour == hh}).groupby("time.season") - score_hourly_mean_sea, score_hourly_std_sea = xr.concat([score_hourly_mean_sea, tmp.mean()], dim="hour"), \ - xr.concat([score_hourly_std_sea, tmp.std()], dim="hour") # create plots create_line_plot(score_hourly_mean, score_hourly_std, model_type.upper(), {score_name.upper(): score_unit}, os.path.join(plot_dir, f"downscaling_{model_type}_{score_name.lower()}.png"), **plt_kwargs) - for sea in score_hourly_mean_sea["season"]: + scores_to_csv(score_hourly_mean, score_hourly_std, score_name, fname=os.path.join(metric_dir, f"eval_{score_name}_year.csv")) + + score_seas = score_all.groupby("time.season") + for sea, score_sea in score_seas: + score_sea_hh = score_sea.groupby("time.hour") + score_sea_hh_mean, score_sea_hh_std = score_sea_hh.mean(), score_sea_hh.std() func_logger.debug(f"Evaluation for season '{sea}'...") - create_line_plot(score_hourly_mean_sea.sel({"season": sea}), - score_hourly_std_sea.sel({"season": sea}), + create_line_plot(score_sea_hh_mean, + score_sea_hh.std(), model_type.upper(), {score_name.upper(): score_unit}, - os.path.join(plot_dir, f"downscaling_{model_type}_{score_name.lower()}_{sea.values}.png"), + os.path.join(plot_dir, f"downscaling_{model_type}_{score_name.lower()}_{sea}.png"), **plt_kwargs) - return True + + scores_to_csv(score_sea_hh_mean, score_sea_hh_std, score_name, + fname=os.path.join(metric_dir, f"eval_{score_name}_{sea}.csv")) + return score_all -def run_evaluation_spatial(score_engine, score_name: str, plot_dir: str, **plt_kwargs): +def run_evaluation_spatial(score_engine, score_name: str, plot_dir: str, + dims = ["rlat", "rlon"], **plt_kwargs): """ Create map plots of desired evaluation metric. Evaluation metric must be given in rotated coordinates. - To-Do: Add flexibility regarding underlying coordinate data (i.e. projection). 
:param score_engine: Score engine object to comput evaluation metric - :param score_name: Name of evaluation metric (must be implemented into score_engine) :param plot_dir: Directory to save plot files + :param dims: Spatial dimension names """ # get local logger - func_logger = logging.getLogger(f"{logger_module_name}.{run_evaluation_spatial.__name__}") - + func_logger = logging.getLogger(f"{logger_module_name}.{run_evaluation_time.__name__}") + os.makedirs(plot_dir, exist_ok=True) model_type = plt_kwargs.get("model_type", "wgan") score_all = score_engine(score_name) - cosmo_prj = crs.RotatedPole(pole_longitude=-162.0, pole_latitude=39.25) + score_all = score_all.drop_vars("variables") score_mean = score_all.mean(dim="time") fname = os.path.join(plot_dir, f"downscaling_{model_type}_{score_name.lower()}_avg_map.png") - create_map_score(score_mean, fname, score_dims=["rlat", "rlon"], - title=f"{score_name.upper()} (avg.)", projection=cosmo_prj, **plt_kwargs) + create_map_score(score_mean, fname, dims=dims, + title=f"{score_name.upper()} (avg.)", **plt_kwargs) score_hourly_mean = score_all.groupby("time.hour").mean(dim=["time"]) for hh in range(24): func_logger.debug(f"Evaluation for {hh:02d} UTC") fname = os.path.join(plot_dir, f"downscaling_{model_type}_{score_name.lower()}_{hh:02d}_map.png") create_map_score(score_hourly_mean.sel({"hour": hh}), fname, - score_dims=["rlat", "rlon"], title=f"{score_name.upper()} {hh:02d} UTC", - projection=cosmo_prj, **plt_kwargs) + dims=dims, title=f"{score_name.upper()} {hh:02d} UTC", + **plt_kwargs) for hh in range(24): score_now = score_all.isel({"time": score_all.time.dt.hour == hh}).groupby("time.season").mean(dim="time") for sea in score_now["season"]: - func_logger.debug(f"Evaluation for season '{str(sea)}' at {hh:02d} UTC") + func_logger.debug(f"Evaluation for season '{str(sea.values)}' at {hh:02d} UTC") fname = os.path.join(plot_dir, f"downscaling_{model_type}_{score_name.lower()}_{sea.values}_{hh:02d}_map.png") - create_map_score(score_now.sel({"season": sea}), fname, score_dims=["rlat", "rlon"], - title=f"{score_name} {sea.values} {hh:02d} UTC", projection=cosmo_prj, **plt_kwargs) + create_map_score(score_now.sel({"season": sea}), fname, dims=dims, + title=f"{score_name} {sea.values} {hh:02d} UTC", **plt_kwargs) return True + + +def scores_to_csv(score_mean, score_std, score_name, fname="scores.csv"): + """ + Save scores to csv file + :param score_mean: Hourly mean of score + :param score_std: Hourly standard deviation of score + :param score_name: Name of score + :param fname: Filename of csv file + """ + # get local logger + func_logger = logging.getLogger(f"{logger_module_name}.{scores_to_csv.__name__}") + + df_mean = score_mean.to_dataframe(name=f"{score_name}_mean") + df_std = score_std.to_dataframe(name=f"{score_name}_std") + df = df_mean.join(df_std) + + func_logger.info(f"Save values of {score_name} to {fname}...") + df.to_csv(fname) diff --git a/downscaling_ap5/postprocess/statistical_evaluation.py b/downscaling_ap5/postprocess/statistical_evaluation.py index 1c767e2d312221ff7b74f207a1a54a5c18465f00..a555fc5e8ce0bfc88659b16da4413c9d75205bbe 100644 --- a/downscaling_ap5/postprocess/statistical_evaluation.py +++ b/downscaling_ap5/postprocess/statistical_evaluation.py @@ -8,22 +8,31 @@ Collection of auxiliary functions for statistical evaluation and class for Score __email__ = "m.langguth@fz-juelich.de" __author__ = "Michael Langguth" -__date__ = "2022-09-11" +__date__ = "2023-10-10" -import numpy as np -import xarray as xr from typing 
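`scores_to_csv` above writes one row per hour with `<score>_mean` and `<score>_std` columns, which is exactly the layout the meta-postprocessing notebooks read back into a `(hour, type)` DataArray. A small round-trip sketch with synthetic numbers (assuming `postprocess.py` is importable from the working directory):

```python
import numpy as np
import pandas as pd
import xarray as xr
from postprocess import scores_to_csv

hours = np.arange(24)
rmse_mean = xr.DataArray(1. + 0.3*np.sin(hours/24.*2.*np.pi), coords={"hour": hours}, dims=["hour"])
rmse_std = xr.DataArray(np.full(24, 0.2), coords={"hour": hours}, dims=["hour"])

scores_to_csv(rmse_mean, rmse_std, "rmse", fname="eval_rmse_year.csv")

# ... read back as done in meta_postprocessing_rmse.ipynb
da_rmse = xr.DataArray(pd.read_csv("eval_rmse_year.csv", header=0, index_col=0),
                       dims=["hour", "type"], coords={"hour": hours, "type": ["mean", "std"]})
print(da_rmse.sel({"type": "mean"}).values[:3])
```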
import Union, List -import datetime -import pandas as pd try: from tqdm import tqdm l_tqdm = True except: l_tqdm = False +import logging +import numpy as np +import pandas as pd +import xarray as xr +from skimage.util.shape import view_as_blocks +from handle_data_class import HandleDataClass +from model_utils import convert_to_xarray from other_utils import provide_default, check_str_in_list + # basic data types da_or_ds = Union[xr.DataArray, xr.Dataset] +list_or_str = Union[List[str], str] + +# auxiliary variable for logger +logger_module_name = f"main_postprocess.{__name__}" +module_logger = logging.getLogger(logger_module_name) def calculate_cond_quantiles(data_fcst: xr.DataArray, data_ref: xr.DataArray, factorization="calibration_refinement", @@ -36,21 +45,30 @@ def calculate_cond_quantiles(data_fcst: xr.DataArray, data_ref: xr.DataArray, fa :param quantiles: conditional quantiles :return quantile_panel: conditional quantiles of p(m|o) or p(o|m) """ - method = calculate_cond_quantiles.__name__ + # get local logger + func_logger = logging.getLogger(f"{logger_module_name}.{calculate_cond_quantiles.__name__}") # sanity checks if not isinstance(data_fcst, xr.DataArray): - raise ValueError("%{0}: data_fcst must be a DataArray.".format(method)) + err_mess = f"data_fcst must be a DataArray, but is of type '{type(data_fcst)}'." + func_logger.error(err_mess, stack_info=True, exc_info=True) + raise ValueError(err_mess) if not isinstance(data_ref, xr.DataArray): - raise ValueError("%{0}: data_ref must be a DataArray.".format(method)) + err_mess = f"data_ref must be a DataArray, but is of type '{type(data_ref)}'." + func_logger.error(err_mess, stack_info=True, exc_info=True) + raise ValueError(err_mess) if not (list(data_fcst.coords) == list(data_ref.coords) and list(data_fcst.dims) == list(data_ref.dims)): - raise ValueError("%{0}: Coordinates and dimensions of data_fcst and data_ref must be the same".format(method)) + err_mess = f"Coordinates and dimensions of data_fcst and data_ref must be the same." + func_logger.error(err_mess, stack_info=True, exc_info=True) + raise ValueError(err_mess) nquantiles = len(quantiles) if not nquantiles >= 3: - raise ValueError("%{0}: quantiles must be a list/tuple of at least three float values ([0..1])".format(method)) + err_mess = f"Quantiles must be a list/tuple of at least three float values ([0..1])." 
+ func_logger.error(err_mess, stack_info=True, exc_info=True) + raise ValueError(err_mess) if factorization == "calibration_refinement": data_cond = data_fcst @@ -59,8 +77,9 @@ def calculate_cond_quantiles(data_fcst: xr.DataArray, data_ref: xr.DataArray, fa data_cond = data_ref data_tar = data_fcst else: - raise ValueError("%{0}: Choose either 'calibration_refinement' or 'likelihood-base_rate' for factorization" - .format(method)) + err_mess = f"Choose either 'calibration_refinement' or 'likelihood-base_rate' for factorization" + func_logger.error(err_mess, stack_info=True, exc_info=True) + raise ValueError(err_mess) # get and set some basic attributes data_cond_longname = provide_default(data_cond.attrs, "longname", "conditioning_variable") @@ -87,7 +106,7 @@ def calculate_cond_quantiles(data_fcst: xr.DataArray, data_ref: xr.DataArray, fa attrs={"cond_var_name": data_cond_longname, "cond_var_unit": data_cond_unit, "tar_var_name": data_tar_longname, "tar_var_unit": data_tar_unit}) - print("%{0}: Start caclulating conditional quantiles for all {1:d} bins.".format(method, nbins)) + func_logger.info(f"Start caclulating conditional quantiles for all {nbins:d} bins.") # fill the quantile data array for i in np.arange(nbins): # conditioning of ground truth based on forecast @@ -97,6 +116,37 @@ def calculate_cond_quantiles(data_fcst: xr.DataArray, data_ref: xr.DataArray, fa return quantile_panel, data_cond +def get_cdf_of_x(sample_in, prob_in): + """ + Wrappper for interpolating CDF-value for given data + :param sample_in : input values to derive discrete CDF + :param prob_in : corresponding CDF + :return: lambda function converting arbitrary input values to corresponding CDF value + """ + return lambda xin: np.interp(xin, sample_in, prob_in) + +def get_seeps_matrix(seeps_param): + """ + Converts SEEPS paramter array to SEEPS matrix. + :param seeps_param: Array providing p1 and p3 parameters of SEEPS weighting matrix. + :return seeps_matrix: 3x3 weighting matrix for the SEEPS-score + """ + # initialize matrix + seeps_weights = xr.full_like(seeps_param["p1"], np.nan) + seeps_weights = seeps_weights.expand_dims(dim={"weights":np.arange(9)}, axis=0).copy() + seeps_weights.name = "SEEPS weighting matrix" + + # off-diagonal elements + seeps_weights[{"weights": 1}] = 1./(1. - seeps_param["p1"]) + seeps_weights[{"weights": 2}] = 1./seeps_param["p3"] + 1./(1. - seeps_param["p1"]) + seeps_weights[{"weights": 3}] = 1./seeps_param["p1"] + seeps_weights[{"weights": 5}] = 1./seeps_param["p3"] + seeps_weights[{"weights": 6}] = 1./seeps_param["p1"] + 1./(1. - seeps_param["p3"]) + seeps_weights[{"weights": 7}] = 1./(1. - seeps_param["p3"]) + # diagnol elements + seeps_weights[{"weights": [0, 4, 8]}] = xr.where(np.isnan(seeps_weights[{"weights": 7}]), np.nan, 0.) 
+ + return seeps_weights def perform_block_bootstrap_metric(metric: da_or_ds, dim_name: str, block_length: int, nboots_block: int = 1000, seed: int = 42): @@ -109,14 +159,18 @@ def perform_block_bootstrap_metric(metric: da_or_ds, dim_name: str, block_length :param seed: seed for random block sampling (to be held constant for reproducability) :return: bootstrapped version of metric(-s) """ - - method = perform_block_bootstrap_metric.__name__ + # get local logger + func_logger = logging.getLogger(f"{logger_module_name}.{perform_block_bootstrap_metric.__name__}") if not isinstance(metric, da_or_ds.__args__): - raise ValueError("%{0}: Input metric must be a xarray DataArray or Dataset and not {1}".format(method, - type(metric))) + err_mess = f"Input metric must be a xarray DataArray or Dataset and not {type(metric)}." + func_logger.error(err_mess, stack_info=True, exc_info=True) + raise ValueError(err_mess) + if dim_name not in metric.dims: - raise ValueError("%{0}: Passed dimension cannot be found in passed metric.".format(method)) + err_mess = f"Passed dimension {dim_name} cannot be found in passed metric." + func_logger.error(err_mess, stack_info=True, exc_info=True) + raise ValueError(err_mess) metric = metric.sortby(dim_name) @@ -124,8 +178,9 @@ def perform_block_bootstrap_metric(metric: da_or_ds, dim_name: str, block_length nblocks = int(np.floor(dim_length/block_length)) if nblocks < 10: - raise ValueError("%{0}: Less than 10 blocks are present with given block length {1:d}." - .format(method, block_length) + " Too less for bootstrapping.") + err_mess = f"Less than 10 blocks are present with given block length {block_length:d}. Too less for bootstrapping." + func_logger.error(err_mess, stack_info=True, exc_info=True) + raise ValueError(err_mess) # precompute metrics of block for iblock in np.arange(nblocks): @@ -143,7 +198,7 @@ def perform_block_bootstrap_metric(metric: da_or_ds, dim_name: str, block_length np.random.seed(seed) iblocks_boot = np.sort(np.random.randint(nblocks, size=(nboots_block, nblocks))) - print("%{0}: Start block bootstrapping...".format(method)) + func_logger.info("Start block bootstrapping...") iterator_b = np.arange(nboots_block) if l_tqdm: iterator_b = tqdm(iterator_b) @@ -163,6 +218,281 @@ def perform_block_bootstrap_metric(metric: da_or_ds, dim_name: str, block_length return metric_boot +def get_domain_info(da: xr.DataArray, lonlat_dims: list =["lon", "lat"], re:float = 6378*1.e+03): + """ + Get information about the underlying grid of a DataArray. + Assumes a regular, spherical grid (can also be a rotated one of lonlat_dims are adapted accordingly) + :param da: The xrray DataArray given on a regular, spherical grid (i.e. 
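For `get_seeps_matrix` above, a small hypothetical example shows the expected input layout (a dataset holding the climatological SEEPS parameters `p1` and `p3` per grid point) and the shape of the resulting weight array; the parameter values are invented:

```python
import numpy as np
import xarray as xr
from statistical_evaluation import get_seeps_matrix

seeps_param = xr.Dataset(
    {"p1": xr.DataArray([[0.55, 0.60], [0.50, 0.65]], dims=["lat", "lon"]),
     "p3": xr.DataArray([[0.15, 0.13], [0.17, 0.12]], dims=["lat", "lon"])})

seeps_weights = get_seeps_matrix(seeps_param)
print(seeps_weights.shape)   # (9, 2, 2): a flattened 3x3 weighting matrix per grid point
```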
providing latitude and longitude coordinates) + :param lonlat_dims: Names of the longutude and latitude coordinates + :param re: radius of spherical Earth + :return grid_dict: dictionary providing dx (grid spacing), nx (#gridpoints) and Lx(domain length) (same for y) as well as lat0 (central latitude) + """ + # get local logger + func_logger = logging.getLogger(f"{logger_module_name}.{get_domain_info.__name__}") + + lon, lat = da[lonlat_dims[0]], da[lonlat_dims[1]] + + try: + assert lon.ndim, f"Longitude data must be a 1D-array, but is a {lon.ndim:d}D-array" + except AssertionError as e: + func_logger.error(e, stack_info=True, exc_info=True) + raise e + + try: + assert lat.ndim, f"Latitude data must be a 1D-array, but is a {lat.ndim:d}D-array" + except AssertionError as e: + func_logger.error(e, stack_info=True, exc_info=True) + raise e + + lat0 = np.mean(lat) + nx, ny = len(lon), len(lat) + dx, dy = (lon[1] - lon[0]).values, (lat[1] - lat[0]).values + + deg2m = re*2*np.pi/360. + Lx, Ly = np.abs(nx*dx*deg2m)*np.cos(np.deg2rad(lat0)), np.abs(ny*dy*deg2m) + + grid_dict = {"nx": nx, "ny": ny, "dx": dx, "dy": dy, + "Lx": Lx, "Ly": Ly, "lat0": lat0} + + return grid_dict + + +def detrend_data(da: xr.DataArray, xy_dims: list =["lon", "lat"]): + """ + Detrend data on a limited area domain to majke it periodic in horizontal directions. + Method based on Errico, 1985. + :param da: The data given on a regular (spherical) grid + :param xy_dims: Names of horizontal dimensions + :return detrended, periodic data: + """ + + x_dim, y_dim = xy_dims[0], xy_dims[1] + nx, ny = len(da[x_dim]), len(da[y_dim]) + + # remove trend in x-direction + fac_x = xr.full_like(da, 1.) * xr.DataArray(2*np.arange(nx) - nx, dims=x_dim, coords={x_dim: da[x_dim]}) + fac_y = xr.full_like(da, 1.)* xr.DataArray(2*np.arange(ny) - ny, dims=y_dim, coords={y_dim: da[y_dim]}) + trend_x, _ = xr.broadcast((da.isel({x_dim: -1}) - da.isel({x_dim: 0}))/float(nx-1), da) + da = da - 0.5 * trend_x*fac_x + # remove trend in y-direction + trend_y, _ = xr.broadcast((da.isel({y_dim: -1}) - da.isel({y_dim: 0}))/float(ny-1), da) + da = da - 0.5 * trend_y*fac_y + + return da + + +def angular_integration(da_fft, grid_dict: dict, lcutoff: bool = True): + """ + Get power spectrum as a function of the total wavenumber. + The integration in the (k_x, k_y)-plane is done by summation over (k_x, k_y)-pairs lying in annular rings, cf. Durran et al., 2017 + :param da_fft: Fast Fourier transformed data with (lat, lon) as last two dimensions + :param grid_dict: dictionary providing information on underlying grid (generated by get_domain_info-method) + :param lcutoff: flag if spectrum should be truncated to cutoff frequency or if full spectrum should be returned (False) + :return power spectrum in total wavenumber space. 
+ """ + # get local logger + func_logger = logging.getLogger(f"{logger_module_name}.{angular_integration.__name__}") + + sh = da_fft.shape + nx, ny = grid_dict["nx"], grid_dict["ny"] + dk = np.array([2.*np.pi/grid_dict["Lx"], 2.*np.pi/grid_dict["Ly"]]) + idkh = np.argmax(dk) + dkx, dky = dk + nmax = int(np.round(np.sqrt(np.square(nx*dkx) + np.square(ny*dky))/dk[idkh])) + + sh = (*sh[:-2], nmax) if da_fft.ndim >= 2 else (1, nmax) # add singleton dim if da_fft is a 2D-array only + spec_radial = np.zeros(sh, dtype="float") + + # start angular integration + func_logger.info(f"Start angular integration for {nmax:d} wavenumbers.") + for i in range(nx): + for j in range(ny): + k_now = int(np.round(np.sqrt(np.square(i*dkx) + np.square(j*dky))/dk[idkh])) + spec_radial[..., k_now] += da_fft[..., i, j].real**2 + da_fft[..., i, j].imag**2 + + if lcutoff: + # Cutting/truncating the spectrum is required to ensure that the combinations of kx and ky + # to yield the total wavenumber are complete. Without this truncation, the spectrum gets distorted + # as argued in Errico, 1985 (cf. Eq. 6 therein). + # Note that Errico, 1985 chooses the maximum of (nx, ny) since dk is defined as dk=min(kx, ky). + # Here, we choose dk = max(dkx, dky) following Durran et al., 2017 (see Eq. 18 therein), and thus have to choose min(nx, ny). + cutoff = int(np.round(min(np.array([nx, ny]))/2 + 0.01)) + spec_radial = spec_radial[..., :cutoff] + + return np.squeeze(spec_radial) + + +def get_spectrum(da: xr.DataArray, lonlat_dims = ["lon", "lat"], lcutoff: bool = True, re: float = 6378*1e+03): + """ + Compute power spectrum in terms of total wavenumber from numpy-array. + Note: Assumes a regular, spherical grid with dx=dy. + :param da: DataArray with (lon, lat)-like dimensions + :param lcutoff: flag if spectrum should be truncated to cutoff frequency or if full spectrum should be returned (False) + :param re: (spherical) Earth radius + :return var_rad: power spectrum in terms of wavenumber + """ + # get local logger + func_logger = logging.getLogger(f"{logger_module_name}.{get_spectrum.__name__}") + + # sanity check on # dimensions + grid_dict = get_domain_info(da, lonlat_dims, re=re) + nx, ny = grid_dict["nx"], grid_dict["ny"] + lx, ly = grid_dict["Lx"], grid_dict["Ly"] + + da = da.transpose(..., *lonlat_dims) + # detrend data to get periodic boundary values (cf. Errico, 1985) + da_var = detrend_data(da, xy_dims=lonlat_dims) + # ... and apply FFT + func_logger.info("Start computing Fourier transformation") + fft_var = np.fft.fft2(da_var.values)/float(nx*ny) + + var_rad = angular_integration(fft_var, {"nx": nx, "ny": ny, "Lx": lx, "Ly": ly}, lcutoff) + + return var_rad + + +def sample_permut_xyt(da_orig: xr.DataArray, patch_size:tuple = (6, 6)): + """ + Permutes sample in a spatio-temporal way following the method of Breiman (2001). + The concrete implementation follows Höhlein et al., 2020 with spatial permutation based on patching. + Note that the latter allows to handle time-invariant data. + :param da_orig: original sample. Must be 3D with a 'time'-dimension + :param patch_size: tuple for patch size + :return: spatio-temporally permuted sample + """ + # get local logger + func_logger = logging.getLogger(f"{logger_module_name}.{sample_permut_xyt.__name__}") + + try: + assert da_orig.ndim == 3, f"da_orig must be a 3D-array, but has {da_orig.ndim} dimensions." 
+ except AssertionError as e: + func_logger.error(e, stack_info=True, exc_info=True) + raise e + + coords_orig = da_orig.coords + dims_orig = da_orig.dims + sh_orig = da_orig.shape + + # temporal permutation + func_logger.info(f"Start spatio-temporal permutation for sample with shape {sh_orig}.") + + ntimes = len(da_orig["time"]) + if dims_orig != "time": + da_orig = da_orig.transpose("time", ...) + coords_now, dims_now, sh_now = da_orig.coords, da_orig.dims, da_orig.shape + else: + coords_now, dims_now, sh_now = coords_orig, dims_orig, sh_orig + + da_permute = np.random.permutation(da_orig).copy() + da_permute = xr.DataArray(da_permute, coords=coords_now, dims=dims_now) + + # spatial permutation with patching + # time must be last dimension (=channel dimension) + sh_aux = da_permute.transpose(..., "time").shape + + # Note that the order of x- and y-coordinates does not matter here + da_patched = view_as_blocks(da_permute.transpose(..., "time").values, block_shape=(*patch_size, ntimes)) + + # convert to DataArray + sh = da_patched.shape + dims = ["pat_x", "pat_y", "dummy", "ix", "iy", "time"] + + da_patched = xr.DataArray(da_patched, coords={dims[0]: np.arange(sh[0]), dims[1]: np.arange(sh[1]), "dummy": range(1), + dims[3]: np.arange(sh[3]), dims[4]: np.arange(sh[4]), "time": da_permute["time"]}, + dims=dims) + + # stack xy-patches and permute + da_patched = da_patched.stack({"pat_xy": ["pat_x", "pat_y"]}) + da_patched[...] = np.random.permutation(da_patched.transpose()).transpose() + + # unstack + da_patched = da_patched.unstack().transpose(*dims) + + # revert view_as_blocks-opertaion + da_patched = da_patched.values.transpose([0, 3, 1, 4, 2, 5]).reshape(sh_aux) + + # write data back on da_permute + da_permute[...] = np.moveaxis(da_patched, 2, 0) + + # transpose to original dimension ordering if required + da_permute = da_permute.transpose(*dims_orig) + + return da_permute + + +def feature_importance(da: xr.DataArray, predictors: list_or_str, varname_tar: str, model, norm, score_name: str, + data_loader_opt: dict, patch_size = (6, 6), variable_dim = "variable"): + """ + Run featiure importance analysis based on permutation method (see signature of sample_permut_xyt-method) + :param da: The (test-)data provided in DataArray with variable-dimension + :param predictors: List of predictor variables + :param varname_tar: Name of target variable + :param model: Trained model for inference + :param norm: Normalization object + :param score_name: Name of metric-score to be calculated + :param data_loader_opt: Dictionary providing options for the TensorFlow data pipeline + :param patch_size: Tuple for patch size during spatio-temporal permutation + :param variable_dim: Name of variable dimension + :return score_all: DataArray with scores for all predictor variables + """ + # get local logger + func_logger = logging.getLogger(f"{logger_module_name}.{feature_importance.__name__}") + + # sanity checks + _ = check_str_in_list(list(da[variable_dim]), predictors) + try: + assert da.dims[0] == "time", f"First dimension of the data must be a time-dimensional, but is {da.dims[0]}." 
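To make the spatio-temporal permutation above concrete, here is a minimal call of `sample_permut_xyt` on a synthetic predictor field whose spatial sizes are divisible by the 6x6 patch size; the data is random and only the shapes matter:

```python
import numpy as np
import xarray as xr
from statistical_evaluation import sample_permut_xyt

da = xr.DataArray(np.random.rand(48, 12, 18), dims=["time", "lat", "lon"],
                  coords={"time": np.arange(48), "lat": np.arange(12), "lon": np.arange(18)})

da_perm = sample_permut_xyt(da, patch_size=(6, 6))

# values are only rearranged in time and in spatial patches, not altered
print(np.allclose(np.sort(da.values.ravel()), np.sort(da_perm.values.ravel())))   # True
```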
+ except AssertionError as e: + func_logger.error(e, stack_info=True, exc_info=True) + raise e + + ntimes = len(da["time"]) + + # get ground truth data and underlying metadata + ground_truth = norm.denormalize(da.sel({variable_dim: varname_tar}), varname=varname_tar) + + # initialize score-array + score_all = xr.DataArray(np.zeros((len(predictors), ntimes)), coords={"predictor": predictors, "time": da["time"]}, + dims=["predictor", "time"]) + + for var in predictors: + func_logger.info(f"Run sample importance analysis for {var}...") + # get copy of sample array + da_copy = da.copy(deep=True) + # permute sample + da_permut = sample_permut_xyt(da.sel({variable_dim: var}).copy(), patch_size=patch_size) + da_copy.loc[{variable_dim: var}] = da_permut + + # get TF dataset + func_logger.info(f"Set-up data pipeline with permuted sample for {var}...") + tfds_test = HandleDataClass.make_tf_dataset_allmem(da_copy, **data_loader_opt) + + # predict + func_logger.info(f"Run inference with permuted sample for {var}...") + y_pred = model.predict(tfds_test, verbose=2) + + # convert to xarray + y_pred = convert_to_xarray(y_pred, norm, varname_tar, ground_truth.coords, ground_truth.dims, True) + + # calculate score + func_logger.info(f"Calculate score for permuted samples of {var}...") + score_engine = Scores(y_pred, ground_truth, dims=ground_truth.dims[1::]) + score_all.loc[{"predictor": var}] = score_engine(score_name) + + #free_mem([da_copy, da_permut, tfds_test, y_pred, score_engine]) + + return score_all + + + + + + + + + + + class Scores: """ Class to calculate scores and skill scores. @@ -224,16 +554,121 @@ class Scores: def avg_dims(self, dims): if dims is None: self.avg_dims = self.data_dims - print("Scores will be averaged across all data dimensions.") + # print("Scores will be averaged across all data dimensions.") else: dim_stat = [avg_dim in self.data_dims for avg_dim in dims] if not all(dim_stat): - ind_bad = [i for i, x in enumerate(dim_stat) if x] - raise ValueError("The following dimensions for score-averaging are not" + - "part of the data: {0}".format(", ".join(dims[ind_bad]))) + ind_bad = [i for i, x in enumerate(dim_stat) if not x] + raise ValueError("The following dimensions for score-averaging are not " + + "part of the data: {0}".format(", ".join(np.array(dims)[ind_bad]))) self._avg_dims = dims + def get_2x2_event_counts(self, thresh): + """ + Get counts of 2x2 contingency tables + :param thres: threshold to define events + :return: (a, b, c, d)-tuple of 2x2 contingency table + """ + a = ((self.data_fcst >= thresh) & (self.data_ref >= thresh)).sum(dim=self.avg_dims) + b = ((self.data_fcst >= thresh) & (self.data_ref < thresh)).sum(dim=self.avg_dims) + c = ((self.data_fcst < thresh) & (self.data_ref >= thresh)).sum(dim=self.avg_dims) + d = ((self.data_fcst < thresh) & (self.data_ref < thresh)).sum(dim=self.avg_dims) + + return a, b, c, d + + def calc_ets(self, thresh=0.1): + """ + Calculates Equitable Threat Score (ETS) on data. + :param thres: threshold to define events + :return: ets-values + """ + a, b, c, d = self.get_2x2_event_counts(thresh) + n = a + b + c + d + ar = (a + b)*(a + c)/n # random reference forecast + + denom = (a + b + c - ar) + + ets = (a - ar)/denom + ets = ets.where(denom > 0, np.nan) + + return ets + + def calc_fbi(self, thresh=0.1): + """ + Calculates Frequency bias (FBI) on data. 
+ :param thres: threshold to define events + :return: fbi-values + """ + a, b, c, d = self.get_2x2_event_counts(thresh) + + denom = a+c + fbi = (a + b)/denom + + fbi = fbi.where(denom > 0, np.nan) + + return fbi + + def calc_pss(self, thresh=0.1): + """ + Calculates Peirce Skill Score (PSS) on data. + :param thres: threshold to define events + :return: pss-values + """ + a, b, c, d = self.get_2x2_event_counts(thresh) + + denom = (a + c)*(b + d) + pss = (a*d - b*c)/denom + + pss = pss.where(denom > 0, np.nan) + + return pss + + def calc_l1(self, **kwargs): + """ + Calculate the L1 error norm of forecast data w.r.t. reference data. + L1 will be divided by the number of samples along the average dimensions. + Similar to MAE, but provides just a number divided by number of samples along average dimensions. + :return: L1-error + """ + if kwargs: + print("Passed keyword arguments to calc_l1 are without effect.") + + l1 = np.sum(np.abs(self.data_fcst - self.data_ref)) + + len_dims = np.array([self.data_fcst.sizes[dim] for dim in self.avg_dims]) + l1 /= np.prod(len_dims) + + return l1 + + def calc_l2(self, **kwargs): + """ + Calculate the L2 error norm of forecast data w.r.t. reference data. + Similar to RMSE, but provides just a number divided by number of samples along average dimensions. + :return: L2-error + """ + if kwargs: + print("Passed keyword arguments to calc_l2 are without effect.") + + l2 = np.sum(np.square(self.data_fcst - self.data_ref)) + + len_dims = np.array([self.data_fcst.sizes[dim] for dim in self.avg_dims]) + l2 /= np.prod(len_dims) + + return l2 + + def calc_mae(self, **kwargs): + """ + Calculate mean absolute error (MAE) of forecast data w.r.t. reference data + :return: MAE averaged over provided dimensions + """ + if kwargs: + print("Passed keyword arguments to calc_mae are without effect.") + + mae = np.abs(self.data_fcst - self.data_ref).mean(dim=self.avg_dims) + + return mae + def calc_mse(self, **kwargs): """ Calculate mse of forecast data w.r.t. reference data @@ -251,6 +686,25 @@ class Scores: rmse = np.sqrt(self.calc_mse(**kwargs)) return rmse + + def calc_acc(self, clim_mean: xr.DataArray, spatial_dims: List = ["lat", "lon"]): + """ + Calculate anomaly correlation coefficient (ACC). + :param clim_mean: climatological mean of the data + :param spatial_dims: names of spatial dimensions over which ACC are calculated. + Note: No averaging is possible over these dimensions. + :return acc: Averaged ACC (except over spatial_dims) + """ + + fcst_ano, obs_ano = self.data_fcst - clim_mean, self.data_ref - clim_mean + + acc = (fcst_ano*obs_ano).sum(spatial_dims)/np.sqrt(fcst_ano.sum(spatial_dims)*obs_ano.sum(spatial_dims)) + + mean_dims = [x for x in self.avg_dims if x not in spatial_dims] + if len(mean_dims) > 0: + acc = acc.mean(mean_dims) + + return acc def calc_bias(self, **kwargs): @@ -297,6 +751,131 @@ class Scores: ratio_spat_variability = ratio_spat_variability.mean(dim=avg_dims) return ratio_spat_variability + + def calc_iqd(self, xnodes=None, nh=0, lfilter_zero=True): + """ + Calculates squared integrated distance between simulation and observational data. + Method: Retrieves the empirical CDF, calculates CDF(xnodes) for both data sets and + then uses the squared differences at xnodes for trapezodial integration. + Note, that xnodes should be selected in a way, that CDF(xvalues) increases + more or less continuously by ~0.01 - 0.05 for increasing xnodes-elements + to ensure accurate integration. 
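The new categorical and error metrics of the `Scores` class can be exercised directly on synthetic data; the constructor call follows the usage in `feature_importance` above, while the precipitation-like values and the 0.5 mm threshold are invented for illustration:

```python
import numpy as np
import xarray as xr
from statistical_evaluation import Scores

rng = np.random.default_rng(1)
dims = ["time", "lat", "lon"]
coords = {"time": np.arange(10), "lat": np.arange(8), "lon": np.arange(8)}
precip_ref = xr.DataArray(rng.gamma(0.5, 2.0, size=(10, 8, 8)), dims=dims, coords=coords)
precip_fcst = xr.DataArray(rng.gamma(0.5, 2.0, size=(10, 8, 8)), dims=dims, coords=coords)

score_engine = Scores(precip_fcst, precip_ref, dims=["lat", "lon"])

ets = score_engine.calc_ets(thresh=0.5)   # Equitable Threat Score per time step
fbi = score_engine.calc_fbi(thresh=0.5)   # frequency bias per time step
mae = score_engine.calc_mae()             # MAE averaged over lat/lon per time step
```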
+ + :param data_simu : 1D-array of simulation data + :param data_obs : 1D-array of (corresponding) observational data + :param xnodes : x-values used as nodes for integration (optional, automatically set if not given + according to stochastic properties of precipitation data) + :param nh : accumulation period (affects setting of xnodes) + :param lfilter_zero: Flag to filter out zero values from CDF calculation + :return: integrated quadrated distance between CDF of data_simu and data_obs + """ + + data_simu = self.data_fcst.values.flatten() + data_obs = self.data_ref.values.flatten() + + # Scarlet: because data_simu and data_obs are flattened anyway this is not needed + #if np.ndim(data_simu) != 1 or np.ndim(data_obs) != 1: + # raise ValueError("Input data arrays must be 1D-arrays.") + + if xnodes is None: + if nh == 1: + xnodes = [0., 0.005, 0.01, 0.015, 0.025, 0.04, 0.06, 0.08, 0.1, 0.13, 0.16, 0.2, 0.25, 0.3, 0.4, 0.5, 0.6, + 0.8, 1., 1.25, 1.5, 1.8, 2.4, 3., 3.75, 4.5, 5.25, 6., 7., 9., 12., 20., 30., 50.] + elif 1 < nh <= 6: + ### obtained manually based on observational data between May and July 2017 + ### except for the first step and the highest node-values, + ### CDF is increased by 0.03 - 0.05 with every step ensuring accurate integration + xnodes = [0.00, 0.005, 0.01, 0.02, 0.04, 0.07, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.7, 0.9, 1.15, 1.5, 1.9, 2.4, + 3., 4., 5., 6., 7.5, 10., 15., 25., 40., 60.] + else: + xnodes = [0.00, 0.01, 0.02, 0.035, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, 1., 1.3, 1.7, 2.1, 2.5, + 3., 3.5, 4., 4.75, 5.5, 6.3, 7.1, 8., 9., 10., 12.5, 15., 20., 25., 35., 50., 70., 100.] + + data_simu_filt = data_simu[~np.isnan(data_simu)] + data_obs_filt = data_obs[~np.isnan(data_obs)] + if lfilter_zero: + data_simu_filt = np.sort(data_simu_filt[data_simu_filt > 0.]) + data_obs_filt = np.sort(data_obs_filt[data_obs_filt > 0.]) + else: + data_simu_filt = np.sort(data_simu_filt) + data_obs_filt = np.sort(data_obs_filt) + + nd_points_simu = np.shape(data_simu_filt)[0] + nd_points_obs = np.shape(data_obs_filt)[0] + + prob_simu = 1. * np.arange(nd_points_simu)/ (nd_points_simu - 1) + prob_obs = 1. * np.arange(nd_points_obs)/ (nd_points_obs -1) + + cdf_simu = get_cdf_of_x(data_simu_filt,prob_simu) + cdf_obs = get_cdf_of_x(data_obs_filt,prob_obs) + + yvals_simu = cdf_simu(xnodes) + yvals_obs = cdf_obs(xnodes) + + if yvals_obs[-1] < 0.999: + print("CDF of last xnodes {0:5.2f} for observation data is smaller than 99.9%." + + "Consider setting xnodes manually!") + + if yvals_simu[-1] < 0.999: + print("CDF of last xnodes {0:5.2f} for simulation data is smaller than 99.9%." + + "Consider setting xnodes manually!") + + # finally, perform trapezodial integration + return np.trapz(np.square(yvals_obs - yvals_simu), xnodes) + + def calc_seeps(self, seeps_weights: xr.DataArray, t1: xr.DataArray, t3: xr.DataArray, spatial_dims: List): + """ + Calculates stable equitable error in probabiliyt space (SEEPS), see Rodwell et al., 2011 + :param seeps_weights: SEEPS-parameter matrix to weight contingency table elements + :param t1: threshold for light precipitation events + :param t3: threshold for strong precipitation events + :param spatial_dims: list/name of spatial dimensions of the data + :return seeps skill score (i.e. 
1-SEEPS) + """ + + def seeps(data_ref, data_fcst, thr_light, thr_heavy, seeps_weights): + ob_ind = (data_ref > thr_light).astype(int) + (data_ref >= thr_heavy).astype(int) + fc_ind = (data_fcst > thr_light).astype(int) + (data_fcst >= thr_heavy).astype(int) + indices = fc_ind * 3 + ob_ind # index of each data point in their local 3x3 matrices + seeps_val = seeps_weights[indices, np.arange(len(indices))] # pick the right weight for each data point + + return 1.-seeps_val + + if self.data_fcst.ndim == 3: + assert len(spatial_dims) == 2, f"Provide two spatial dimensions for three-dimensional data." + data_fcst, data_ref = self.data_fcst.stack({"xy": spatial_dims}), self.data_ref.stack({"xy": spatial_dims}) + seeps_weights = seeps_weights.stack({"xy": spatial_dims}) + t3 = t3.stack({"xy": spatial_dims}) + lstack = True + elif self.data_fcst.ndim == 2: + data_fcst, data_ref = self.data_fcst, self.data_ref + lstack = False + else: + raise ValueError(f"Data must be a two-or-three-dimensional array.") + + # check dimensioning of data + assert data_fcst.ndim <= 2, f"Data must be one- or two-dimensional, but has {data_fcst.ndim} dimensions. Check if stacking with spatial_dims may help." + + if data_fcst.ndim == 1: + seeps_values_all = seeps(data_ref, data_fcst, t1.values, t3, seeps_weights) + else: + data_fcst, data_ref = data_fcst.transpose(..., "xy"), data_ref.transpose(..., "xy") + seeps_values_all = xr.full_like(data_fcst, np.nan) + seeps_values_all.name = "seeps" + for it in range(data_ref.shape[0]): + data_fcst_now, data_ref_now = data_fcst[it, ...], data_ref[it, ...] + # in case of missing data, skip computation + if np.all(np.isnan(data_fcst_now)) or np.all(np.isnan(data_ref_now)): + continue + + seeps_values_all[it,...] = seeps(data_ref_now, data_fcst_now, t1.values, t3, seeps_weights.values) + + if lstack: + seeps_values_all = seeps_values_all.unstack() + + seeps_values = seeps_values_all.mean(dim=self.avg_dims) + + return seeps_values @staticmethod def calc_geo_spatial_diff(scalar_field: xr.DataArray, order: int = 1, r_e: float = 6371.e3, dom_avg: bool = True): diff --git a/downscaling_ap5/preprocess/add_invariant_data_atmorep.ipynb b/downscaling_ap5/preprocess/add_invariant_data_atmorep.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..d8f6d8ad0f2cf6906c612b586556115317dbae15 --- /dev/null +++ b/downscaling_ap5/preprocess/add_invariant_data_atmorep.ipynb @@ -0,0 +1,188 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "ceeb7395-688e-43ee-9aa4-55818a592555", + "metadata": { + "tags": [] + }, + "source": [ + "# Notebook to add constant variables to competing AtmoRep downscaling data\n", + "\n", + "This Notebook processes the files generated with `preprocees_data_atmorep.sh` to add the surface topography from ERA5 and COSMO REA6 data which both constitute invariant fields, but have to be expanded to include a time-dimension." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7426cf19-7ff9-4358-8615-57164e7c7f9e", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "! 
pip install findlibs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a2b2f8f1-7112-42ac-bf48-8018eb7e4b1d", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import os\n", + "import sys\n", + "import glob\n", + "from tqdm import tqdm\n", + "\n", + "import pandas as pd\n", + "import numpy as np\n", + "import xarray as xr\n", + "import cfgrib" + ] + }, + { + "cell_type": "markdown", + "id": "8b6adba5-71f6-4ab3-b866-c76fb02cedeb", + "metadata": { + "tags": [] + }, + "source": [ + "Parameters:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c2220e62-6873-4fb4-86e9-2c40a3b519ac", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "data_dir=\"/p/scratch/deepacf/maelstrom/maelstrom_data/ap5/competing_atmorep/\"\n", + "invar_file_era5 = \"/p/scratch/deepacf/maelstrom/maelstrom_data/ap5/competing_atmorep/reanalysis_orography.nc\"\n", + "invar_file_crea6 = \"/p/scratch/atmo-rep/data/cosmo_rea6/static/cosmo_rea6_orography.nc\"" + ] + }, + { + "cell_type": "markdown", + "id": "afb9b245-9ecf-4c91-8103-de4811240e0e", + "metadata": { + "tags": [] + }, + "source": [ + "The file 'invar_file_era5' has been generated with the following CDO-command:\n", + "``` \n", + "cdo --reduce_dim -t ecmwf -f nc copy -remapbil,~/downscaling_maelstrom/downscaling_jsc_repo/downscaling_ap5/grid_des/crea6_reg_grid reanalysis_orography.grib reanalysis_orography.nc\n", + "``` \n", + "where the original grib-file was obatined from AtmoRep (```/p/scratch/atmo-rep/data/era5/static```)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cacd266c-6e0e-476d-9805-f4fac289e3cf", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "file_list = glob.glob(os.path.join(data_dir, \"downscaling_atmorep*.nc\"))\n", + "\n", + "if len(file_list) == 0:\n", + " raise FileNotFoundError(f\"Could not find any datafiles under '{data_dir}'...\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2c9bb25b-4930-47db-8f9d-a6e1cbc59571", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "ds_invar_era5 = xr.open_dataset(invar_file_era5)\n", + "ds_invar_crea6 = xr.open_dataset(invar_file_crea6).sel({\"lat\": ds_invar_era5[\"lat\"], \"lon\": ds_invar_era5[\"lon\"]})\n", + "ds_invar_crea6 = ds_invar_crea6.drop_vars(\"FR_LAND\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3ebb63ef-9dd0-49e1-be15-560f87c51166", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "for f in tqdm(file_list):\n", + " # read current file\n", + " print(f\"Process data-file '{f}'...\")\n", + " ds_now = xr.open_dataset(f)\n", + " var_list = list(ds_now.data_vars)\n", + " lchange = False\n", + " \n", + " if \"z_in\" not in var_list:\n", + " print(f\"Add surface topography from ERA5...\")\n", + " dst = ds_invar_era5.expand_dims(time=ds_now[\"time\"])\n", + " dst = dst.rename({\"Z\": \"z_in\"})\n", + " \n", + " ds_all = xr.merge([ds_now, dst])\n", + " lchange = True\n", + " \n", + " if \"hsurf_tar\" not in var_list:\n", + " print(f\"Add surface topography from CREA6...\")\n", + " dst = ds_invar_crea6.expand_dims(time=ds_now[\"time\"])\n", + " dst = dst.rename({\"z\": \"hsurf_tar\"})\n", + " \n", + " ds_all = xr.merge([ds_all , dst])\n", + " lchange = True\n", + " \n", + " if \"t2m_ml0_tar\" in var_list:\n", + " ds_all = ds_all.rename({\"t2m_ml0_tar\": \"t2m_tar\"})\n", + " lchange = True\n", + " \n", + " if lchange:\n", + " print(f\"Write modified dataset back to '{f}'...\")\n", + " 
ds_all.to_netcdf(f.replace(\".nc\", \"_new.nc\"))\n", + " else:\n", + " print(f\"No changes to data from '{f}' applied. Continue...\")\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "91ceb153-5900-428f-a097-fd9aab19c669", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "langguth1_downscaling_kernel", + "language": "python", + "name": "langguth1_downscaling_kernel" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/downscaling_ap5/preprocess/download_era5_data.py b/downscaling_ap5/preprocess/download_era5_data.py new file mode 100644 index 0000000000000000000000000000000000000000..caffe444b5b2ba30d3c1397a2f7cfdc5ba441dbd --- /dev/null +++ b/downscaling_ap5/preprocess/download_era5_data.py @@ -0,0 +1,288 @@ +# SPDX-FileCopyrightText: 2022 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC) +# +# SPDX-License-Identifier: MIT + +""" +Script to download ERA5 data from the CDS API. +""" + +__author__ = "Michael Langguth" +__email__ = "m.langguth@fz-juelich.de" +__date__ = "2023-11-21" +__update__ = "2023-08-22" + +# import modules +import os, sys +import logging +import cdsapi +import numpy as np +import pandas as pd +from multiprocessing import Pool +from other_utils import to_list + +# get logger +logger_module_name = f"main_download_era5.{__name__}" +module_logger = logging.getLogger(logger_module_name) + +# known request parameters +knwon_req_keys = ["ml", "sfc"] + + +class ERA5_Data_Loader(object): + """ + Class to download ERA5 data from the CDS API. + """ + + knwon_req_keys = ["ml", "sfc"] + allowed_formats = ["netcdf", "grib"] + # defaults + area = [75, -45, 20, 65] + month_start = 1 + month_end = 12 + + + def __init__(self, nworkers) -> None: + + # get local logger + func_logger = logging.getLogger(f"{logger_module_name}.{self.__init__.__name__}") + + self.nworkers = nworkers + try: + self.cds = cdsapi.Client() + except Exception as e: + func_logger.error(f"Could not initialize CDS API client: {e} \n" + \ + "Please follow the instructions at https://cds.climate.copernicus.eu/api-how-to to install the CDS API.") + raise e + + def __call__(self, req_dict, data_dir, start, end, format, **kwargs): + """ + Run the requests to download the ERA5 data. + :param req_dict: dictionary with data requests + :param data_dir: directory where output data files will be stored + :param start: start year of data request + :param end: end year of data request + :param format: format of downloaded data (netcdf or grib) + :param kwargs: additional keyword arguments (options: area, month_start, month_end) + """ + + # get local logger + func_logger = logging.getLogger(f"{logger_module_name}.{self.__call__.__name__}") + + # create output directory + if not os.path.exists(data_dir): + func_logger.info(f"Creating output directory {data_dir}") + os.makedirs(data_dir) + + # validate request keys + req_list = self.validate_request_keys(req_dict) + + # select data format + if format not in self.allowed_formats: + func_logger.warning(f"Unknown data format {format}. 
Using default format netcdf.") + format = "netcdf" + + for req_key in req_list: + if req_key == "sfc": + out = self.download_sfc_data(req_dict[req_key], data_dir, start, end, format, **kwargs) + elif req_key == "ml": + out = self.download_ml_data(req_dict[req_key], data_dir, start, end, format, **kwargs) + + # check if all requests were successful + _ = self.check_request_result(out) + + return out + + + def download_sfc_data(self, varlist, data_dir, start, end, format, **kwargs): + """ + Download ERA5 surface data. + To obtain the varlist, please refer to the CDS API documentation at https://cds.climate.copernicus.eu/cdsapp#!/dataset/reanalysis-era5-single-levels?tab=form + :param varlist: list of variables to download + :param data_dir: directory where output data files will be stored + :param start: start year of data request + :param end: end year of data request + :param format: format of downloaded data (netcdf or grib) + :param kwargs: additional keyword arguments (options: area, month_start, month_end) + :return: output of multiprocessing pool + """ + # get local logger + func_logger = logging.getLogger(f"{logger_module_name}.{self.download_sfc_data.__name__}") + + # get additional keyword arguments + area = kwargs.get("area", self.area) + month_start = kwargs.get("month_start", self.month_start) + month_end = kwargs.get("month_end", self.month_end) + + # create base request dictionary (None-values will be set dynamically) + req_dict_base = {"product_type": "reanalysis", "format": f"{format}", + "variable": to_list(varlist), + "day": None, "month": None, "time": [f"{h:02d}" for h in range(24)], "year": None, + "area": area} + + # initialize multiprocessing pool + func_logger.info(f"Downloading ERA5 surface data for variables {', '.join(varlist)} with {self.nworkers} workers.") + pool = Pool(self.nworkers) + + # initialize dictionary for request results + req_results = {} + + # create data requests for each month + for year in range(start, end+1): + req_dict = req_dict_base.copy() + req_dict["year"] = [f"{year}"] + for month in range(month_start, month_end+1): + req_dict["month"] = [f"{month:02d}"] + # get last day of month + last_day = pd.Timestamp(year, month, 1) + pd.offsets.MonthEnd(1) + req_dict["day"] = [f"{d:02d}" for d in range(1, last_day.day+1)] + fout = f"era5_sfc_{year}-{month:02d}.{format}" + + func_logger.debug(f"Downloading ERA5 surface data for {year}-{month:02d} to {os.path.join(data_dir, fout)}") + + req_results[fout] = pool.apply_async(self.cds.retrieve, args=("reanalysis-era5-single-levels", req_dict, + os.path.join(data_dir, fout))) + + # run and close multiprocessing pool + pool.close() + pool.join() + + func_logger.info(f"Finished downloading ERA5 surface data.") + + return req_results + + def download_ml_data(self, var_dict, data_dir, start, end, format, **kwargs): + """ + Download ERA5 data for multiple levels. 
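+        Model-level data is retrieved from the 'reanalysis-era5-complete' archive with MARS-style request
+        keywords; note that all requested levels are downloaded for every variable ('all-together' approach below).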
+ To obtain the varlist, please refer to the CDS API documentation at https://cds.climate.copernicus.eu/cdsapp#!/dataset/reanalysis-era5-complete?tab=form + :param var_dict: dictionary of variables to download for each level + :param data_dir: directory where output data files will be stored + :param start: start year of data request + :param end: end year of data request + :param format: format of downloaded data (netcdf or grib) + :param kwargs: additional keyword arguments (options: area, month_start, month_end) + :return: output of multiprocessing pool + """ + # get local logger + func_logger = logging.getLogger(f"{logger_module_name}.{self.download_ml_data.__name__}") + + # get additional keyword arguments + area = kwargs.get("area", self.area) + month_start = kwargs.get("month_start", self.month_start) + month_end = kwargs.get("month_end", self.month_end) + + vars = var_dict.keys() + vars_param = self.translate_mars_vars(vars) + # ensure that data is downloaded for all levels (All-together approach -> overhead: additional download of data for levels that are not needed) + collector = [] + _ = [collector.extend(ml_list) for ml_list in var_dict.values()] + all_lvls = sorted([str(lvl) for lvl in set(collector)]) + + # create base request dictionary (None-values will be set dynamically) + req_dict_base = {"class": "ea", "date": None, + "expver": "1", + "levelist": "/".join(all_lvls), + "levtype": "ml", + "param": vars_param, + "stream": "oper", + "time": "00/to/23/by/1", + "type": "an", + "area": area , + "grid": "0.25/0.25",} + + # initialize multiprocessing pool + func_logger.info(f"Downloading ERA5 model-level data for variables {', '.join(vars)} with {self.nworkers} workers.") + pool = Pool(self.nworkers) + + # initialize dictionary for request results + req_results = {} + + # create data requests for each month + for year in range(start, end+1): + req_dict = req_dict_base.copy() + for month in range(month_start, month_end+1): + # get last day of month + last_day = pd.Timestamp(year, month, 1) + pd.offsets.MonthEnd(1) + req_dict["date"] = f"{year}-{month:02d}-01/to/{year}-{month:02d}-{last_day.day}" + fout = f"era5_ml_{year}-{month:02d}.{format}" + + func_logger.debug(f"Downloading ERA5 model-level data for {year}-{month:02d} to {os.path.join(data_dir, fout)}") + + req_results[fout] = pool.apply_async(self.cds.retrieve, args=("reanalysis-era5-complete", req_dict, + os.path.join(data_dir, fout))) + + # run and close multiprocessing pool + pool.close() + pool.join() + + func_logger.info(f"Finished downloading ERA5 model-level data.") + + return req_results + + def check_request_result(self, results_dict): + """ + Check if request was successful. + :param results_dict: dictionary with request results (returned by download_ml_data and download_sfc_data) + """ + # get local logger + func_logger = logging.getLogger(f"{logger_module_name}.{self.check_request_result.__name__}") + + # check if all requests were successful + stat = [o.get().__dict__["reply"]["state"] == "completed" for o in results_dict.values()] + + ok = True + + if all(stat): + func_logger.info(f"All requests were successful.") + else: + ok = False + bad_req = np.where(np.array(stat) == False)[0] + results_arr = np.array(list(results_dict.keys())) + func_logger.error(f"The following requests were not successful: {', '.join(results_arr[bad_req])}.") + + return ok + + + def validate_request_keys(self, req_dict): + """ + Validate request keys in data request file. 
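+        Known request keys are 'ml' (model-level data) and 'sfc' (surface data); unknown keys are skipped
+        with a warning.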
+ :param req_dict: dictionary with data requests + :return: list of valid request keys (filtered) + """ + # get local logger + func_logger = logging.getLogger(f"{logger_module_name}.{self.validate_request_keys.__name__}") + + # create list of requests + req_list = [] + for req_key in req_dict.keys(): + if req_key in self.knwon_req_keys: + req_list.append(req_key) + else: + func_logger.warning(f"Unknown request key {req_key} in data request file. Skipping this request.") + + return req_list + + def translate_mars_vars(self, vars): + """ + Translate variable names to MARS parameter names. + :param vars: list of variable names + :return: list of MARS parameter names + """ + + # create dictionary with variable names and corresponding MARS parameter names + var_dict = {"z": "129", "t": "130", "u": "131", "v": "132", "w": "135", "q": "133", + "temperature": "130", "specific_humidity": "133", "geopotential": "129", + "vertical_velocity": "135", "u_component_of_wind": "131", "v_component_of_wind": "132", + "vorticity": "138", "divergence": "139", "logarithm_of_surface_pressure": "152", + "fraction_of_cloud_cover": "164", "specific_cloud_liquid_water_content": "246", + "specific_cloud_ice_water_content": "247", "specific_rain_water_content": "248", + "specific_snow_water_content": "249", } + + # translate variable names + vars_param = [var_dict[var] for var in vars] + + return vars_param + + + + + diff --git a/downscaling_ap5/preprocess/preprocess_data_atmorep.sh b/downscaling_ap5/preprocess/preprocess_data_atmorep.sh new file mode 100755 index 0000000000000000000000000000000000000000..1b406a6d8b8f75b594970b76d1e347015c58e2d3 --- /dev/null +++ b/downscaling_ap5/preprocess/preprocess_data_atmorep.sh @@ -0,0 +1,101 @@ +#!/bin/bash + +############################################################################## +# Script to process the ERA5 and COSMO REA6 data that has also been used # +# in the AtMoRep project. To be processable for the data loader in MAELSTROM,# +# the input and target data must be provided in monthly netCDF files. # +# Furthermore, input and target variables must be denoted with _in and _tar, # +# respectively. # +############################################################################## + +# author: Michael Langguth +# date: 2023-08-16 +# update: 2023-08-16 + +# parameters + +era5_basedir=/p/scratch/atmo-rep/data/era5/ml_levels/ +crea6_basedir=/p/scratch/atmo-rep/data/cosmo_rea6/ml_levels/ +output_dir=/p/scratch/deepacf/maelstrom/maelstrom_data/ap5/competing_atmorep/ +crea6_gdes=../grid_des/crea6_reg_grid + +era5_vars=("t") +era5_vars_full=("temperature") +crea6_vars=("t2m") +crea6_vars_full=("t_2m") + +ml_lvl_era5=( 96 105 114 123 137 ) +ml_lvl_crea6=( 0 ) + +year_start=1995 +year_end=2018 + +# main + +echo "Loading required modules..." +ml purge +ml Stages/2022 GCC/11.2.0 OpenMPI/4.1.2 NCO/5.0.3 CDO/2.0.2 + +tmp_dir=${output_dir}/tmp + +# create output directory +if [ ! -d $output_dir ]; then + mkdir -p $output_dir +fi + +if [ ! -d $tmp_dir ]; then + mkdir -p $tmp_dir +fi + +# loop over years and months +for yr in $(eval echo "{$year_start..$year_end}"); do + for mm in {01..12}; do + echo "Processing ERA5 data for ${yr}-${mm}..." + for ivar in "${!era5_vars[@]}"; do + echo "Processing variable ${era5_vars[ivar]} from ERA5..." 
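+            # For each model level: convert the monthly ERA5 grib file to netCDF, slice it to the region of
+            # interest, remap it bilinearly onto the regular CREA6 grid and rename the variable to
+            # <var>_ml<lvl>_in so that the data loader recognizes it as an input field.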
+ for ml_lvl in ${ml_lvl_era5[@]}; do + echo "Processing data for level ${ml_lvl}" + # get file name + era5_file=${era5_basedir}/${ml_lvl}/"${era5_vars_full[ivar]}"/reanalysis_"${era5_vars_full[ivar]}"_y${yr}_m${mm}_ml${ml_lvl}.grib + tmp_file1=${tmp_dir}/era5_${era5_vars[ivar]}_y${yr}_m${mm}_ml${ml_lvl}_tmp1.nc + tmp_file2=${tmp_dir}/era5_${era5_vars[ivar]}_y${yr}_m${mm}_ml${ml_lvl}_tmp2.nc + tmp_era5=${tmp_dir}/era5_${era5_vars[ivar]}_y${yr}_m${mm}_ml${ml_lvl}_tmp3.nc + # convert to netCDF and slice to region of interest + cdo -f nc copy -sellonlatbox,-1.5,25.75,42.25,56 ${era5_file} ${tmp_file1} + cdo remapbil,${crea6_gdes} ${tmp_file1} ${tmp_file2} + # rename variable + cdo --reduce_dim chname,${era5_vars[ivar]},${era5_vars[ivar]}_ml${ml_lvl}_in -selvar,${era5_vars[ivar]} ${tmp_file2} ${tmp_era5} + # clean-up + rm ${tmp_file1} ${tmp_file2} + done + done + + era5_file_now=${tmp_dir}/era5_y${yr}_m${mm}.nc + cdo merge ${tmp_dir}/*.nc ${era5_file_now} + # clean-up + rm ${tmp_dir}/era5*tmp*.nc + + echo "Processing COSMO-REA6 data for ${yr}-${mm}..." + for ivar in ${!crea6_vars[@]}; do + for ml_lvl in ${ml_lvl_crea6}; do + crea6_file=${crea6_basedir}/${ml_lvl}/${crea6_vars[ivar]}/cosmo_rea6_${crea6_vars[ivar]}_y${yr}_m${mm}_ml${ml_lvl}.nc + tmp_crea6=${tmp_dir}/crea6_${crea6_vars[ivar]}_y${yr}_m${mm}_ml${ml_lvl}_tmp_crea6.nc + cdo --reduce_dim -chname,${crea6_vars_full[ivar]},${crea6_vars[ivar]}_ml${ml_lvl}_tar -sellonlatbox,-1.25,25.6875,42.3125,55.75 -selvar,${crea6_vars_full[ivar]} ${crea6_file} ${tmp_crea6} + done + done + + crea6_file_now=${tmp_dir}/crea6_y${yr}_m${mm}.nc + cdo merge ${tmp_dir}/*tmp_crea6.nc ${crea6_file_now} + + cdo merge ${crea6_file_now} ${era5_file_now} ${output_dir}/downscaling_atmorep_train_${yr}-${mm}.nc; + + # clean-up + rm ${tmp_dir}/*.nc + done +done + + + + + + diff --git a/downscaling_ap5/preprocess/preprocess_data_era5_to_crea6.py b/downscaling_ap5/preprocess/preprocess_data_era5_to_crea6.py old mode 100644 new mode 100755 diff --git a/downscaling_ap5/utils/other_utils.py b/downscaling_ap5/utils/other_utils.py index 2853df1e88789774b1821e482a81df47f9443ab5..649f2de8db122d8c95d4f1ab6888aa71a12a1c22 100644 --- a/downscaling_ap5/utils/other_utils.py +++ b/downscaling_ap5/utils/other_utils.py @@ -12,24 +12,27 @@ Some auxiliary functions for the project: * subset_files_on_date * extract_date * ensure_datetime + * doy_to_mo * last_day_of_month * flatten * remove_files * check_str_in_list + * shape_from_str * find_closest_divisor - * free_mem +# * free_mem * print_gpu_usage * print_cpu_usage * get_memory_usage * get_max_memory_usage * copy_filelist + * merge_dicts """ # doc-string __author__ = "Michael Langguth" __email__ = "m.langguth@fz-juelich.de" __date__ = "2022-01-20" -__update__ = "2023-03-17" +__update__ = "2023-11-27" import os import gc @@ -170,6 +173,18 @@ def ensure_datetime(date): return date_dt +def doy_to_mo(day_of_year: int, year: int): + """ + Converts day of year to year-month datetime object (e.g. in_day = 2, year = 2017 yields January 2017). 
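+    The conversion relies on pd.to_datetime with format '%Y%j', so the returned timestamp also carries
+    the day of the month (e.g. 2017-01-02 in the example above).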
+ From AtmoRep-project: https://isggit.cs.uni-magdeburg.de/atmorep/atmorep/ + :param day_of_year: day of year + :param year: corresponding year (to reflect leap years) + :return year-month datetime object + """ + date_month = pd.to_datetime(year * 1000 + day_of_year, format='%Y%j') + return date_month + + def last_day_of_month(any_day): """ Returns the last day of a month @@ -249,6 +264,19 @@ def check_str_in_list(list_in: List, str2check: str_or_List, labort: bool = True return stat, [] +def shape_from_str(fname): + """ + Retrieves shapes from AtmoRep output-filenames. + From AtmoRep-project: https://isggit.cs.uni-magdeburg.de/atmorep/atmorep/ + :param fname: filename of AtmoRep output-file + :return shapes inferred from AtmoRep output-file + """ + shapes_str = fname.replace("_phys.dat", ".dat").split(".")[-2].split("_s_")[-1].split("_") #remove .ext and split + shapes = [int(i) for i in shapes_str] + return shapes + + + def find_closest_divisor(n1, div): """ Function to find closest divisor for a given number with respect to a target value @@ -275,16 +303,18 @@ def find_closest_divisor(n1, div): return all_divs[i] -def free_mem(var_list: List): - """ - Delete all variables in var_list and release memory - :param var_list: list of variables to be deleted - """ - var_list = to_list(var_list) - for var in var_list: - del var - - gc.collect() +#def free_mem(var_list: List): +# *** This was found to be in effective, cf. michael_issue085-memory_footprint *** +# +# """ +# Delete all variables in var_list and release memory +# :param var_list: list of variables to be deleted +# """ +# var_list = to_list(var_list) +# for var in var_list: +# del var +# +# gc.collect() # The following auxiliary methods have been adapted from MAELSTROM AP3, # see https://git.ecmwf.int/projects/MLFET/repos/maelstrom-radiation/browse/climetlab_maelstrom_radiation/benchmarks/ @@ -365,4 +395,23 @@ def copy_filelist(file_list: List, dest_dir: str, file_list_dest: List = None ,l if labort: raise FileNotFoundError(f"Could not find file '{f}'. Error will be raised.") else: - print(f"WARNING: Could not find file '{f}'.") \ No newline at end of file + print(f"WARNING: Could not find file '{f}'.") + +def merge_dicts(default_dict, user_dict): + """ + Recursively merge two dictionaries, ensuring that all default keys are set. + """ + merged_dict = default_dict.copy() + + for key, value in user_dict.items(): + if isinstance(value, dict) and key in merged_dict and isinstance(merged_dict[key], dict): + # If the value is a dictionary and the key exists in both dictionaries, + # recursively merge the dictionaries. + merged_dict[key] = merge_dicts(merged_dict[key], value) + else: + # Otherwise, set the value in the merged dictionary. 
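+            # A user value whose type differs from the default's type is rejected by the assert below;
+            # this implicitly assumes that every user key is already present in the default dictionary.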
+ assert isinstance(value, type(merged_dict[key])), \ + f"Type mismatch for key '{key}': {type(value)} != {type(merged_dict[key])}" + merged_dict[key] = value + + return merged_dict diff --git a/downscaling_ap5/utils/read_atmorep_data.py b/downscaling_ap5/utils/read_atmorep_data.py new file mode 100644 index 0000000000000000000000000000000000000000..ab8646131030f687bbf33eaa574dfd6eb19e3d1c --- /dev/null +++ b/downscaling_ap5/utils/read_atmorep_data.py @@ -0,0 +1,761 @@ +# SPDX-FileCopyrightText: 2023 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC) +# +# SPDX-License-Identifier: MIT + +__author__ = "Michael Langguth" +__email__ = "m.langguth@fz-juelich.de" +__date__ = "2023-05-21" +__update__ = "2023-08-15" + +# import modules +import os, sys +sys.path.append("./") +import glob +import json +from typing import List, Tuple +from time import time as timer +from tqdm import tqdm + +import pandas as pd +import xarray as xr +import numpy as np + +from other_utils import doy_to_mo, shape_from_str + +# the class to handle AtmoRep data +class HandleAtmoRepData(object): + """ + Handle outout data of AtmoRep. + TO-DO: + - get dx-info from token-config + - add ibatch-information to coordinates of masked data + - sanity check on data (partly done) + """ + known_data_types = ["source", "prediction", "target", "ensembles"] + + # offsets to get reconstruct correct timestamps + days_offset = 1 + hours_offset = 0 + + def __init__(self, model_id: str, dx_in: float, atmorep_dir: str = "/p/scratch/atmo-rep/results/", + in_dir: str = "/p/scratch/atmo-rep/data/era5/ml_levels/", + target_type: str = "fields_prediction", dx_tar: float = None, + tar_dir: str = None, epsilon: float = 0.001): + """ + :param model_id: ID of Atmorep-run to process + :param dx_in: grid spacing of input data + :param atmorep_dir: Base-directory where AtmoRep-output is located + :param in_dir: Base-directory for input data used to train AtmoRep (to get normalization/correction-data) + :param target_type: Either fields_prediction or fields_target + :param dx_tar: grid spacing of target data + :param tar_dir: Base-directory for target data used to train AtmoRep (to get normalization/correction-data) + :param epsilon: epsilon paramter for log transformation (if applied at all) + """ + self.model_id = model_id if model_id.startswith("id") else f"id{model_id}" + self.datadir = os.path.join(atmorep_dir, self.model_id) + self.datadir_input = in_dir + self.datadir_target = self.datadir_input if tar_dir is None else tar_dir + self.target_type = target_type + self.dx_in = dx_in + self.dx_tar = self.dx_in if dx_tar is None else dx_tar + self.epsilon = epsilon + self.config_file, self.config = self._get_config() + self.input_variables = self._get_invars() + self.target_variables = self._get_tarvars() + + self.input_token_config = self.get_input_token_config() + self.target_token_config = self.get_target_token_config() + + def _get_config(self) -> Tuple[str, dict]: + """ + Get configuration dictionary of trained AtmoRep-model. + """ + config_jsf = os.path.join(self.datadir, f"model_{self.model_id}.json") + with open(config_jsf) as json_file: + config = json.load(json_file) + return config_jsf, config + + def _get_invars(self) -> List: + """ + Get list of input variables of trained AtmoRep-model. + """ + return [var_list[0] for var_list in self.config["fields"]] + #return list(np.asarray(self.config["fields"], dtype=object)[:, 0]) + + def _get_tarvars(self) -> List: + """ + Get list of target variables of trained AtmoRep-model. 
+ """ + return [var_list[0] for var_list in self.config[self.target_type]] + #return list(np.asarray(self.config[self.target_type])[:, 0]) + + def _get_token_config(self, key) -> dict: + """ + Generic function to retrieve token configuration. + :param key: key-string from which token-info is deducable + :return dictionary of token info + """ + token_config_keys = ["general_config", "vlevel", "num_tokens", "token_shape", "bert_parameters"] + token_config = {var[0]: None for var in self.config[key]} + for i, var in enumerate(token_config): + len_config = len(self.config[key][i]) + token_config[var] = {config_key: self.config[key][i][j+1] for j, config_key in enumerate(token_config_keys)} + if len_config >= 7: + token_config[var]["norm_type"] = self.config[key][i][6] + if len_config >= 8: + if isinstance(self.config[key][i][7][-1], bool): + token_config[var]["log_transform"] = self.config[key][i][7][-1] + else: + token_config[var]["log_transform"] = False + else: # default setting for normalization type + token_config[var]["norm_type"] = "global" + + return token_config + + def get_input_token_config(self) -> dict: + """ + Get input token configuration + """ + return self._get_token_config("fields") + + def get_target_token_config(self) -> dict: + """ + Retrieve token configuration of output/target data. + Note that the token configuration is the same as the input data as long as target_fields is unset. + """ + if self.target_type in ["target_fields", "fields_targets"]: + return self._get_token_config(self.target_type) + else: + return self.input_token_config + + def _get_token_info(self, token_type: str, rank: int = 0, epoch: int = 0, batch: int = 0, varname: str = None, lmasked: bool = True): + """ + Retrieve token information. + Note: For BERT-training (i.e. lmasked == True), the tokens are scattered in time and space. + Thus, the token information will be returned as an index-based array, whereas the token information + is returned in a structured manner (i.e. shaped as (nbatch, nvlevel, nt, ny, nx) where nt, ny, nx + represent the number of tokens in (time, lat, lon)-dimension. 
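+        The last dimension of the token info holds the token meta-data, whose leading entries encode year,
+        day of year and hour, followed by the vertical level and the (lat, lon)-position of the token
+        (cf. the _get_date_* and _get_grid_* methods).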
+ :param token_type: Type of token for which info should be retrieved (either 'input', 'target', 'ensemble' or 'prediction') + :param rank: rank (of job) that has created the requested token information file + :param epoch: training epoch of requested token information file + :param batch: batch of requested token information + :param varname: name of variable for which token info is requested + :param lmasked: flag if masking was applied on token (for token_type 'target' or 'prediction' only) + """ + if self.dx_tar != self.dx_in: # in case of different grid spacing in target and source data, target specific token info files are present + add_str_tar = "target" + # ML: Required due to inconsistency in naming convetion for files providing token info + # and masked token indices (see below) + add_str_tar2 = "targets" + else: + add_str_tar, add_str_tar2 = "", "" + var_aux = varname + if token_type == self.known_data_types[0]: + var_aux = self.input_variables[0] if var_aux is None else var_aux + fpatt = os.path.join(self.datadir, f"*rank{rank}_epoch{epoch:05d}_batch{batch:05d}_token_infos_{var_aux}*.dat") + token_dict = self.input_token_config[var_aux] + elif token_type in self.known_data_types[1:]: + var_aux = self.target_variables[0] if var_aux is None else var_aux + fpatt = os.path.join(self.datadir, f"*rank{rank}_epoch{epoch:05d}_batch{batch:05d}_{add_str_tar}_token_infos_{var_aux}*.dat") + token_dict = self.target_token_config[var_aux] + else: + raise ValueError(f"Parsed token type '{token_type}' is unknown. Choose one of the following: {*self.known_data_types,}") + # Get file for token info + fname_tokinfos_now = self.get_token_file(fpatt) + + # read token info and reshape + shape_aux = shape_from_str(fname_tokinfos_now) + tokinfo_data = np.fromfile(fname_tokinfos_now, dtype=np.float32) + # hack for downscaling + if token_type in ["prediction", "target"] and "downscaling_num_layers" in self.config: + token_dict["num_tokens"] = [1, *token_dict["num_tokens"][1::]] + tokinfo_data = tokinfo_data.reshape(shape_aux[0], len(token_dict["vlevel"]), *token_dict["num_tokens"], shape_aux[-1]) + + if token_type in self.known_data_types[1:] and lmasked: + # read indices of masked tokens + fpatt = os.path.join(self.datadir, f"*rank{rank}_epoch{epoch:05d}_batch{batch:05d}_{add_str_tar2}_tokens_masked_idx_{var_aux}*.dat") + fname_tok_maskedidx = self.get_token_file(fpatt) + + tok_masked_idx = np.fromfile(fname_tok_maskedidx, dtype=np.int64) + # reshape token info for slicing... + tokinfo_data = tokinfo_data.transpose(1, 0, 2, 3, 4, 5).reshape( -1, shape_aux[-1]) + # ... and slice to relevant data + tokinfo_data = tokinfo_data[tok_masked_idx] + + return tokinfo_data + + def _get_date_nomask(self, tokinfo_data, nt): + + assert tokinfo_data.ndim == 6, f"Parsed token info data has {tokinfo_data.ndim} dimensions, but 6 are expected." + + times = np.array(np.floor(np.delete(tokinfo_data, np.s_[3:], axis = -1)), dtype = np.int32) # ? + times = times[:,0,:,0,0,:] # remove redundant dimensions + + years, days, hours = times[:,:,0].flatten(), times[:,:,1].flatten(), times[:,:,2].flatten() + years = np.array([int(y) for y in years]) + + # ML: Should not be required in the future!!! 
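+        # Correct out-of-range day-of-year values: negative days are mapped back into the previous year,
+        # days beyond the last day of the year are rolled over into the following year (leap years included).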
+ if any(days < 0): + years = np.where(days < 0, years - 1, years) + days_new = np.array([pd.Timestamp(y, 12, 31).dayofyear - self.days_offset for y in years]) + days = np.where(days < 0, days_new - days, days) + + if any(days > 364): + days_per_year = np.array([pd.Timestamp(y, 12, 31).dayofyear for y in years]) - self.days_offset + years = np.where(days > days_per_year, years + 1, years) + days = np.where(days > days_per_year, days - days_per_year - self.days_offset, days) + + # construct date information for centered data position of tokens + dates = doy_to_mo(days + self.days_offset, years) + dates = dates + pd.TimedeltaIndex(hours.flatten(), unit='h') + ### ML: Is this really required??? + # appy hourly offset + dates = dates - pd.TimedeltaIndex(np.ones(dates.shape)*self.hours_offset, unit="h") + # reshape and construct remaining date information of tokens + dates = np.array(dates).reshape(times.shape[0:-1]) + dates = np.array([list(dates + pd.TimedeltaIndex(np.ones(dates.shape)*hh, unit="h")) for hh in range(-int(nt/2), int(nt/2) + 1)]) + dates = dates.transpose(1, 2, 0).reshape(times.shape[0], -1) + + return dates + + # times = np.array(np.floor(np.delete(tokinfo_data, np.s_[3:], axis = -1)), dtype = np.int32) # ? + # times = times[:,0,:,0,0,:] # remove redundant dimensions + + # years, days, hours = times[:,:,0].flatten(), times[:,:,1].flatten(), times[:,:,2].flatten() + # years = np.array([int(y) for y in years]) + + # # ML: Should not be required in the future!!! + # if any(days < 0): + # years = np.where(days < 0, years - 1, years) + # days_new = np.array([pd.Timestamp(y, 12, 31).dayofyear - self.days_offset for y in years]) + # days = np.where(days < 0, days_new - days, days) + + # if any(days > 364): + # days_per_year = np.array([pd.Timestamp(y, 12, 31).dayofyear for y in years]) - self.days_offset + # years = np.where(days > days_per_year, years + 1, years) + # days = np.where(days > days_per_year, days - days_per_year - self.days_offset, days) + + # # construct date information for centered data position of tokens + # dates = doy_to_mo(days + self.days_offset, years) + # dates = dates + pd.TimedeltaIndex(hours.flatten(), unit='h') + # ### ML: Is this really required??? + # dates = dates - pd.TimedeltaIndex(np.ones(dates.shape)*self.hours_offset, unit="h") + + # # reshape and construct remaining date information of tokens + # dates = np.array(dates).reshape(times.shape[0:-1]) + # dates = np.array([dates + pd.TimedeltaIndex(np.ones(dates.shape)*hh, unit="h") for hh in range(-int(nt/2), int(nt/2) + 1)]) + # print(dates.shape) + # print(times.shape[0]) + # dates = dates.transpose(1, 2, 0).reshape(times.shape[0], -1) + + # return dates + + def _get_date_masked(self, tokinfo_data, nt): + + assert tokinfo_data.ndim == 2, f"Parsed token info data has {tokinfo_data.ndim} dimensions, but 2 are expected." + + years, days, hours = tokinfo_data[:, 0].flatten(), tokinfo_data[:, 1].flatten(), tokinfo_data[:, 2].flatten() + days = np.array(np.floor(days), dtype=np.int32) + # ML: Should not be required in the future!!! 
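+        # Same out-of-range correction of the day-of-year values as in _get_date_nomask above.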
+ if any(days < 0): + years = np.where(days < 0, years - 1, years) + days_new = np.array([pd.Timestamp(y, 12, 31).dayofyear - self.days_offset for y in years]) + days = np.where(days < 0, days_new - days, days) + + if any(days > 364): + days_per_year = np.array([pd.Timestamp(y, 12, 31).dayofyear for y in years]) - self.days_offset + years = np.where(days > days_per_year, years + 1, years) + days = np.where(days > days_per_year, days - days_per_year - self.days_offset, days) + + dates = doy_to_mo(days + self.days_offset, years) # add 1 to days since counting starts with zero in token_info + dates = dates + pd.TimedeltaIndex(hours, unit='h') + ### ML: Is this really required??? + # appy hourly offset + dates = dates - pd.TimedeltaIndex(np.ones(dates.shape)*self.hours_offset, unit="h") + + # reshape and construct remaining date information of tokens + dates = np.array([dates + pd.TimedeltaIndex(np.ones(dates.shape)*hh, unit="h") for hh in range(-int(nt/2), int(nt/2) + 1)]) + dates = dates.transpose() + + return np.array(dates) + + def get_date(self, tokinfo_data, token_config): + """ + Retrieve dates from token info data + :param tokinfo_data: token info data which was read beforehand by _get_token_info-method + :param token_config: corresponding token configuration + """ + nt = token_config["token_shape"][0] + + ndims_tokinfo = tokinfo_data.ndim + + if ndims_tokinfo == 2: + get_date_func = self._get_date_masked + elif ndims_tokinfo == 6: + get_date_func = self._get_date_nomask + else: + raise ValueError(f"Parsed tokinfo_data-array has unexpected number of dimensions ({ndims_tokinfo}).") + + dates = get_date_func(tokinfo_data, nt) + + return dates + + def get_grid(self, tokinfo_data, token_config, dx): + """ + Retrieve underlying geo/grid information. + :param tokinfo_data: token info data which was read beforehand by _get_token_info-method + :param token_config: corresponding token configuration + :param dx: spacing of underlying grid + """ + ndims_tokinfo = tokinfo_data.ndim + + if ndims_tokinfo == 2: + get_grid_func = self._get_grid_masked + elif ndims_tokinfo == 6: + get_grid_func = self._get_grid_nomask + else: + raise ValueError(f"Parsed tokinfo_data-array has unexpected number of dimensions ({ndims_tokinfo}).") + + lats, lons = get_grid_func(tokinfo_data, token_config, dx) + + return lats, lons + + + def read_one_file(self, fname: str, token_type: str, varname: str, token_config: dict, token_info, dx: float, lmasked: bool = True, ldenormalize: bool = True, + no_mean_denorm: bool = False): + """ + Read data from a single output file of AtmoRep and convert to xarray DataArray with underlying coordinate information. 
+ :param token_type: Type of token for which info should be retrieved (either 'input', 'target', 'ensemble' or 'prediction') + :param rank: rank (of job) that has created the requested token information file + :param epoch: training epoch of requested token information file + :param batch: batch of requested token information + :param varname: name of variable for which token info is requested + :param lmasked: flag if masking was applied on token (for token_type 'target' or 'prediction' only) + :param ldenormalize: flag if denormalize/invert correction should be applied (also includes inversion of log transformation) + :param no_mean_denorm: flag if mean should not be added when denormalization is performed + """ + data = np.fromfile(fname, dtype=np.float32) + + times = self.get_date(token_info, token_config) + lats, lons = self.get_grid(token_info, token_config, dx) + + if token_type in self.known_data_types[1:] and lmasked: + data = self._reshape_masked_data(data, token_config, token_type == "ensembles", self.config["net_tail_num_nets"]) + vlvl = np.array(token_info[:, 3], dtype=np.int32) + data = self.masked_data_to_xarray(data, varname, times, vlvl, lats, lons, token_type == "ensembles") + else: + data = self._reshape_nomask_data(data, token_config, self.config["batch_size_test"]) + data = self.nomask_data_to_xarray(data, varname, times, token_config["vlevel"], lats, lons) + + if ldenormalize: + t0 = timer() + + if getattr(token_config, "log_transform", False): + data = self.invert_log_transform(data) + + if token_type in self.known_data_types[1:] and lmasked: + data = self.denormalize_masked_data(data, token_type, token_config["norm_type"], no_mean_denorm) + else: + data = self.denormalize_nomask_data(data, token_type, token_config["norm_type"], no_mean_denorm) + + data.attrs["denormalization time [s]"] = timer() - t0 + + return data + + def read_data(self, token_type: str, varname: str, rank: int = -1, epoch: int = -1, batch: int = -1, lmasked: bool = True, + ldenormalize: bool = True, no_mean_denorm: bool = False): + """ + Read data from a single output file of AtmoRep and convert to xarray DataArray with underlying coordinate information. + :param token_type: Type of token for which info should be retrieved (either 'input', 'target', 'ensemble' or 'prediction') + :param rank: rank (of job) that has created the requested token information file + :param epoch: training epoch of requested token information file + :param batch: batch of requested token information + :param varname: name of variable for which token info is requested + :param lmasked: flag if masking was applied on token (for token_type 'target' or 'prediction' only) + :param ldenormalize: flag if denormalize/invert correction should be applied (also includes inversion of log transformation) + :param no_mean_denorm: flag if mean should not be added when denormalization is performed + """ + if token_type == "source": + token_type_str = "source" + token_config = self.input_token_config[varname] + dx = self.dx_in + elif token_type in self.known_data_types[1:]: + token_type_str = "preds" if token_type == "prediction" else token_type + token_config = self.target_token_config[varname] + dx = self.dx_tar + else: + raise ValueError(f"Parsed token type '{token_type}' is unknown. 
Choose one of the following: {*self.known_data_types,}") + + filelist = self.get_hierarchical_sorted_files(token_type_str, varname, rank, epoch, batch) + + print(f"Start reading {len(filelist)} files for {token_type} data...") + lwarn = True + for i, f in enumerate(tqdm(filelist)): + rank, epoch, batch = self.get_rank_epoch_batch(f) + try: + token_info = self._get_token_info(token_type, rank=rank, epoch=epoch, batch=batch, varname=varname, lmasked=lmasked) + except FileNotFoundError: + if lwarn: # print warning only once + print(f"No token info for {token_type} data of {varname} found. Proceed with token info for input data.") + lwarn = False + token_info = self._get_token_info(self.known_data_types[0], rank=rank, epoch=epoch, batch=batch, varname=None, lmasked=lmasked) + + da_f = self.read_one_file(f, token_type, varname, token_config, token_info, dx, lmasked, ldenormalize, no_mean_denorm) + + if i == 0: + da_list = [] + denorm_time = 0. + da_list.append(da_f.copy()) + dim_concat = da_f.dims[0] + else: + ilast = da_list[i-1][dim_concat][-1] + 1 + inew = np.arange(ilast, ilast + len(da_f[dim_concat])) + da_f = da_f.assign_coords({dim_concat: inew}) + if ldenormalize: + denorm_time += da_f.attrs["denormalization time [s]"] + + da_list.append(da_f.copy()) + #da = xr.concat([da, da_f], dim=da.dims[0]) + + da = xr.concat(da_list, dim=da_list[0].dims[0]) + if ldenormalize: + da.attrs['denormalization time [s]'] = denorm_time + print(f"Denormalization of {len(filelist)} files for {token_type} data took {da.attrs['denormalization time [s]']:.2f}s") + + return da + + def get_hierarchical_sorted_files(self, token_type_str: str, varname: str, rank: int = -1, epoch: int = -1, batch: int = -1): + rank_str = f"rank*" if rank == -1 else f"rank{rank:d}" + epoch_str = f"epoch*" if epoch == -1 else f"epoch{epoch:05d}" + batch_str = f"batch*" if batch == -1 else f"batch{batch:05d}" + + fpatt = f"*{self.model_id}_{rank_str}_{epoch_str}_{batch_str}_{token_type_str}_{varname}*.dat" + filelist = glob.glob(os.path.join(self.datadir, fpatt)) + + filelist = [f for f in filelist if '000-1' not in f] #remove epoch -1 + + if len(filelist) == 0: + raise FileNotFoundError(f"Could not file any files mathcing pattern '{fpatt}' under directory '{self.datadir}'.") + + # hierarchical sorting: epoch -> rank -> batch + sorted_filelist = sorted(filelist, key=lambda x: self.get_number(x, "_rank")) + sorted_filelist = sorted(sorted_filelist, key=lambda x: self.get_number(x, "_epoch")) + sorted_filelist = sorted(sorted_filelist, key=lambda x: self.get_number(x, "_batch")) + + return sorted_filelist + + def denormalize_global(self, da, param_dir, no_mean = False): + + da_dims = list(da.dims) + varname = da.name + vlvl = da["vlevel"].values #list(da["vlevel"].values)[0] + # get nomralization parameters + mean, std = self.get_global_norm_params(varname, vlvl, param_dir) + if no_mean: + mean[...] = 0. 
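+        # Normalization parameters are stored per calendar month (year_month), so the data is grouped into
+        # monthly chunks below; the original stacked time coordinate is re-attached via time_save before unstacking.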
+ + # re-index data along time dimension for efficient denormalization + time_dims = list(da["time"].dims) + da = da.stack({"time_aux": time_dims}) + time_save = da["time_aux"] # save information for later re-indexing + da = da.set_index({"time_aux": "time"}).sortby("time_aux") + time_save = time_save.sortby("time") + + da = da.resample({"time_aux": "1M"}) + # loop over year-month items + for i, (ts, da_mm) in enumerate(da): + yr_mm = pd.to_datetime(ts).strftime("%Y-%m") + da_mm = da_mm * std.sel({"year_month": yr_mm}).values + mean.sel({"year_month": yr_mm}).values + if i == 0: + da_concat = da_mm.copy() + else: + da_concat = xr.concat([da_concat, da_mm], dim="time_aux") + + da_concat["time_aux"] = time_save + if not xr.__version__ == "0.20.1": + da_concat = da_concat.reset_index("time_aux") + da_concat = da_concat.unstack("time_aux") + + return da_concat.transpose(*da_dims) + + def denormalize_masked_data(self, data: xr.DataArray, token_type: str, norm_type: str, no_mean: bool = False): + """ + Denormalizes/Inverts correction for masked data. + Data has to normalized considering vertical level and time which both vary along token-dimension. + :param data: normalized (xarray) data array providing masked data (cf. masked_data_to_xarray-method) + :param token_type: type of token to be handled, e.g. 'source' (cf. known_data_types) + :param norm_type: type of normalization applied to the data (either 'local' or 'global') + :param no_mean: flag if data normalization has NOT been zero-meaned + """ + mm_yr = np.unique(data["time"].dt.strftime("y%Y_m%m")) + vlevels = np.unique(data["vlevel"]) + varname = data.name + dim0 = data.dims[0] + basedir = self.datadir_input if token_type == self.known_data_types[0] else self.datadir_target + + for vlvl in vlevels: + if norm_type == "local": + datadir = os.path.join(basedir, f"{vlvl}", "corrections", varname) + for month in mm_yr: + fcorr_now = os.path.join(datadir, f"corrections_mean_var_{varname}_{month}_ml{vlvl}.nc") + norm_data = xr.open_dataset(fcorr_now) + mean, std = norm_data[f"{varname}_ml{vlvl:d}_mean"], norm_data[f"{varname}_ml{vlvl:d}_std"] + if no_mean: mean[...] = 0. + + mask = (data["vlevel"] == vlvl) & (data["time"].dt.strftime("y%Y_m%m") == month) + data_now = data.where(mask, drop=True) + for it in data_now[dim0]: # loop over all tokens + xy_dict = {"lat": data_now.sel({dim0: it})["lat"], "lon": data_now.sel({dim0: it})["lon"]} + mu_it, std_it = mean.sel(xy_dict), std.sel(xy_dict) + data_now.loc[{dim0: it}] = data_now.loc[{dim0: it}]*std_it + mu_it + data = xr.where(mask, data_now, data) + + elif norm_type == "global": + mask = data["vlevel"] == vlvl + data = xr.where(mask, self.denormalize_global(data.where(mask), basedir, no_mean = no_mean), data) + + return data + + def denormalize_nomask_data(self, data: xr.DataArray, token_type:str, norm_type:str, no_mean: bool = False): + """ + Denormalizes/Inverts correction for unmasked data. + Data has to normalized considering vertical level and time which both vary along token-dimension. + :param data: normalized (xarray) data array providing unmasked data (cf. nomask_data_to_xarray-method) + :param token_type: type of token to be handled, e.g. 'source' (cf. 
known_data_types) + :param norm_type: type of normalization applied to the data (either 'local' or 'global') + :param no_mean: flag if data normalization has NOT been zero-meaned + """ + times = data["time"] + nt = len(times[0,:]) + dim0 = data.dims[0] + varname = data.name + basedir = self.datadir_input if token_type == self.known_data_types[0] else self.datadir_target + + center_times = times.min(dim="t") + pd.Timedelta(nt/2, "hours") + yr_mm = np.unique(center_times.dt.strftime("y%Y_m%m")) + + for vlvl in data["vlevel"]: + if norm_type == "local": + iv = vlvl.values + data_dir = os.path.join(basedir, f"{iv:d}", "corrections", varname) + for month in yr_mm: + fcorr_now = os.path.join(data_dir, f"corrections_mean_var_{varname}_{month}_ml{iv:d}.nc") + norm_data = xr.open_dataset(fcorr_now) + mean, std = norm_data[f"{varname}_ml{iv:d}_mean"], norm_data[f"{varname}_ml{iv:d}_std"] + if no_mean: mean[...] = 0. + + mask = (data["time"].dt.strftime("y%Y_m%m") == month) + data_now = data.where(mask, drop=True) + for it in data_now[dim0]: + xy_dict = {"lat": data_now.sel({dim0: it})["lat"], "lon": data_now.sel({dim0: it})["lon"]} + mu_it, std_it = mean.sel(xy_dict), std.sel(xy_dict) + data.loc[{dim0: it, "vlevel": iv}] = data.loc[{dim0: it, "vlevel": iv}]*std_it + mu_it + + elif norm_type == "global": + data.loc[{"vlevel": vlvl}] = self.denormalize_global(data.sel({"vlevel": vlvl}), basedir, no_mean = no_mean) + + return data + + def invert_log_transform(self, data): + """ + Inverts log transformation on data. + param data: the xarray DataArray which was log transformed + :return: data after inversion of log transformation + """ + data = self.epsilon*(np.exp(data) - 1.) + + return data + + def get_rank_epoch_batch(self, fname, to_int: bool=True): + rank = self.get_number(fname, "_rank") + epoch = self.get_number(fname, "_epoch") + batch = self.get_number(fname, "_batch") + + if to_int: + rank, epoch, batch = int(rank), int(epoch), int(batch) + + return rank, epoch, batch + + @staticmethod + def nomask_data_to_xarray(data_np, varname: str, times, vlvls, lat, lon, lensemble: bool = False): + + if lensemble: + raise ValueError("Ensemlbe not supported yet.") + + nbatch, nt = data_np.shape[0], data_np.shape[2] + nlat, nlon = len(lat[0, :]), len(lon[0, :]) + + da = xr.DataArray(data_np, dims=["ibatch", "vlevel", "t", "y", "x"], + coords={"ibatch": np.arange(nbatch), "vlevel": vlvls, + "t": np.arange(nt), "y": np.arange(nlat), "x": np.arange(nlon), + "time": (["ibatch", "t"], times), + "lat": (["ibatch", "y"], lat), "lon": (["ibatch", "x"], lon)}, + name=varname) + return da + + @staticmethod + def masked_data_to_xarray(data_np, varname: str, times, vlvls, lat, lon, lensemble: bool = False): + + data_dims = ["itoken", "tt", "yt", "xt"] + if lensemble: + ntoken, nens, ntt, yt, xt = data_np.shape + data_dims.insert(1, "ens") + else: + ntoken, ntt, yt, xt = data_np.shape + + data_coords = {"itoken": np.arange(ntoken), "tt": np.arange(ntt), "yt": np.arange(yt), + "xt": np.arange(xt), "time": (["itoken", "tt"], times), + "vlevel": ("itoken", vlvls), + "lat": (["itoken", "yt"], lat), "lon": (["itoken", "xt"], lon)} + + if lensemble: + data_coords["ens"] = np.arange(nens) + + da = xr.DataArray(data_np, dims=data_dims, coords=data_coords, name = varname) + + return da + + @staticmethod + def get_number(file_name, split_arg): + # Extract the number from the file name using the provided split argument + return int(file_name.split(split_arg)[1].split('_')[0]) + + @staticmethod + def get_token_file(fpatt: str, 
nfiles: int = 1): + # Get file for token info + fnames = glob.glob(fpatt) + # sanity check + if not fnames: + raise FileNotFoundError(f"Could not find required file(-s) using the following filename pattern: '{fpatt}'") + + assert len(fnames) == nfiles, f"More files matching filename pattern '{fpatt}' found than expected." + + if nfiles == 1: + fnames = fnames[0] + + return fnames + + @staticmethod + def _get_grid_nomask(tokinfo_data, token_config, dx): + """ + Retrieve underlying geo/grid information for unmasked data (complete patches!) + :param tokinfo_data: token info data which was read beforehand by _get_token_info-method + :param token_config: corresponding token configuration + :param dx: spacing of underlying grid + """ + # retrieve spatial token size + ny, nx = token_config["token_shape"][1:3] + + # off-centering for even number of grid points + ny1, nx1 = 1, 1 + + if ny%2 == 0: + ny1 = 0 + if nx%2 == 0: + nx1 = 0 + + lat_offset = np.arange(-int((ny-ny1)/2)*dx, int(ny/2+ny1)*dx, dx) + lon_offset = np.arange(-int((nx-nx1)/2)*dx, int(nx/2+nx1)*dx, dx) + + #IMPORTANT: need to swap axes in lats to have ntokens_lat adjacent to tokinfo --> 8x4x8x8 -> 8x8x4x8 + lons = np.array([tokinfo_data.swapaxes(0, -1)[-3]+lon_offset[i] for i in range(len(lon_offset))]).swapaxes(0, -1)%360 + lats = np.array([tokinfo_data.swapaxes(-3, -2).swapaxes(0, -1)[-4]+lat_offset[i] for i in range(len(lat_offset))]).swapaxes(0, -1)%180 + # if(flip_lats): + # lats = np.flip(lats) + + lats, lons = lats[:, 0, 0, 0, :, :], lons[:, 0, 0, 0, :, :] + lats, lons = lats.reshape(lats.shape[0], -1), lons.reshape(lats.shape[0], -1) + + # correct lat values because they are in 'mathematical coordinates' with 0 at the North Pole + lats = 90. -lats + + return lats, lons + + @staticmethod + def _get_grid_masked(tokinfo_data, token_config, dx): + """ + Retrieve underlying geo/grid information for masked data (scattered tokens!) + :param tokinfo_data: token info data which was read beforehand by _get_token_info-method + :param token_config: corresponding token configuration + :param dx: spacing of underlying grid + """ + # retrieve spatial token size + ntok = tokinfo_data.shape[0] + ny, nx = token_config["token_shape"][1:3] + + # off-centering for even number of grid points + ny1, nx1 = 1, 1 + if ny%2 == 0: + ny1 = 0 + if nx%2 == 0: + nx1 = 0 + + lat_offset = np.arange(-int((ny-ny1)/2+nx1)*dx, int(ny/2+ny1)*dx, dx) + lon_offset = np.arange(-int((nx-nx1)/2+ny1)*dx, int(nx/2+nx1)*dx, dx) + + #lats = np.array([tokinfo_data[idx, 4] + np.arange(-int(ny/2)*dx, int(ny/2+1)*dx, dx) for idx in range(ntok)]) % 180 #boundary conditions + #lons = np.array([tokinfo_data[idx, 5] + np.arange(-int(nx/2)*dx, int(nx/2+1)*dx, dx) for idx in range(ntok)]) % 360 #boundary conditions + lats = np.array([tokinfo_data[idx, 4] + lat_offset for idx in range(ntok)]) % 180 #boundary conditions + lons = np.array([tokinfo_data[idx, 5] + lon_offset for idx in range(ntok)]) % 360 #boundary conditions + + # correct lat values because they are in 'mathematical coordinates' with 0 at the North Pole + lats = 90. 
- lats + + return lats, lons + + + @staticmethod + def get_global_norm_params(varname, vlv, basedir): + """ + Read parameter files for global z-score normalization + :param varname: name of variable + :param vlv: vertical level index + :param basedir: base directory under which correction/parameter files are located + """ + fcorr = os.path.join(basedir, f"{vlv}", "corrections", f"global_corrections_mean_var_{varname}_ml{vlv:d}.bin") + corr_data = np.fromfile(fcorr, dtype="float32").reshape(-1, 4) + + years, months = corr_data[:,0], corr_data[:,1] + + # hack: filter data where years equal to zero + bad_inds = np.nonzero(years == 0) + years, months = np.delete(years, bad_inds), np.delete(months, bad_inds) + mean, var = np.delete(corr_data[:,2], bad_inds), np.delete(corr_data[:,3], bad_inds) + + yr_mm = [pd.to_datetime(f"{int(yr):d}-{int(m):02d}", format="%Y-%m") for yr, m in zip(years, months)] + + mean = xr.DataArray(mean, dims=["year_month"], coords={"year_month": yr_mm}) + var = xr.DataArray(var, dims=["year_month"], coords={"year_month": yr_mm}) + + return mean, var + + @staticmethod + def _reshape_nomask_data(data, token_config: dict, batch_size: int): + """ + Reshape unmasked token data that has been read from .dat-files of AtmoRep (complete patches!). + Data is assumed to cover + Adapted from Ilaria, but now with in-place operations to save memory. + """ + sh = (batch_size, len(token_config["vlevel"]), *token_config["num_tokens"], *token_config["token_shape"]) + data = data.reshape(*sh) + + # further reshaping to collapse token dimensions + # The following is adapted from Ilaria, but now with in-place operations to save memory (original code, see below). + data = np.transpose(data, (0,1,2,5,3,6,4,7)) + data = data.reshape(*data.shape[:-2], -1) + data = data.reshape(*data.shape[:-3], -1, *data.shape[-1:]) + data = data.reshape(*data.shape[:2], -1, *data.shape[4:]) + + return data + + @staticmethod + def _reshape_masked_data(data, token_config, lensemble: bool = False, nens: int = None): + sh0 = (-1,) + if lensemble: + assert nens > 0, f"Invalid nens value passed ({nens}). It must be an integer > 0" + sh0 = (-1, nens) + + token_sh = token_config["token_shape"] + data = data.reshape(*sh0, *token_sh) + + return data
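Two illustrative sketches to complement this patch:

First, a simplified, self-contained version of the integrated squared CDF distance computed by the precipitation score above; since `get_cdf_of_x` is not shown in this patch, `scipy.interpolate.interp1d` is used here as a stand-in assumption for the CDF interpolator.

```python
import numpy as np
from scipy.interpolate import interp1d  # assumption: get_cdf_of_x builds a comparable interpolator


def squared_cdf_distance(sim, obs, xnodes):
    """Integrated squared distance between the empirical CDFs of two samples, evaluated at xnodes."""
    def ecdf(data):
        vals = np.sort(data[~np.isnan(data)])
        prob = np.arange(vals.size) / (vals.size - 1)
        # piecewise-linear CDF; 0 below the smallest and 1 above the largest sample value
        return interp1d(vals, prob, bounds_error=False, fill_value=(0., 1.))

    diff = ecdf(obs)(xnodes) - ecdf(sim)(xnodes)
    return np.trapz(np.square(diff), xnodes)
```

Second, a minimal usage sketch of the `HandleAtmoRepData` reader introduced in `read_atmorep_data.py`; the run ID, grid spacings and variable name are placeholders that have to be adapted to an actual AtmoRep run.

```python
import numpy as np
from read_atmorep_data import HandleAtmoRepData

# hypothetical run ID; 0.25° ERA5 input and 0.0625° COSMO-REA6 target grid spacing
ar_data = HandleAtmoRepData("id012345", dx_in=0.25, dx_tar=0.0625)

# read denormalized predictions and targets of the masked (BERT) tokens for a placeholder variable name
da_pred = ar_data.read_data("prediction", "temperature", lmasked=True, ldenormalize=True)
da_tar = ar_data.read_data("target", "temperature", lmasked=True, ldenormalize=True)

# simple skill measure over all masked tokens
rmse = float(np.sqrt(np.square(da_pred - da_tar).mean()))
print(f"RMSE over all masked tokens: {rmse:.3f}")
```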