{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d636ee12-e299-485f-b84d-6f35c05fa766",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, sys\n",
    "import glob\n",
    "sys.path.append(\"../utils/\")\n",
    "import xarray as xr\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from statistical_evaluation import perform_block_bootstrap_metric, avg_metrics, calculate_cond_quantiles, Scores"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8510684d-9374-4e40-bc1c-4d69181c925c",
   "metadata": {},
   "source": [
    "# Evaluation over a smaller (joint) domain\n",
    "\n",
    "The following cells will first merge all forecast files under `indir` into a single netCDF-file.<br>\n",
    "Then the data is sliced to the domain defined by `lonlatbox` and all subsequent evaluation is performed on this smaller domain.<br>\n",
    "The evaluation metrics are then saved to a file under `indir` named `evaluation_metrics_<nlon>x<nlat>.nc` where `nlat` and `nlon` denote the number of grid points/pixels in latitude and longitude direction of the smaller domain, respectively. <br>\n",
    "\n",
    "Thus, first let's define the basic parameters:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "440b15fa-ecd4-4bb4-9100-ede5abb2b04f",
   "metadata": {},
   "outputs": [],
   "source": [
    "indir = \"/p/project/deepacf/deeprain/video_prediction_shared_folder/results/era5-Y2007-2019M01to12-92x56-3840N0000E-2t_tcc_t_850/savp/20210901T090059_gong1_savp_cv12/\"\n",
    "model = \"savp\"\n",
    "# define domain. [3., 24.3, 40.2, 53.1] corresponds to the smallest domain tested in the GMD paper\n",
    "lonlatbox = [3., 24.3, 40.2, 53.1]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fd759f01-2561-4615-8056-036bdee6e2c7",
   "metadata": {},
   "source": [
    "Next, we perform a first merging step. For computational efficiency, we merge max. 1000 files in the first step.<br>\n",
    "Since the data is not sorted by the dimension `init_time` when querying along the sample index, we sort it before saving to intermediate files.<br>\n",
    "\n",
    "Given that the merging step has already been performed, no further processing is required.<br>\n",
    "If this is not the case, we start with the sample indices between 0 and 999:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6726da3-d774-4eda-89d6-e315a865bb99",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_fname_ndigits(indir, prefix, suffix, n, patt=\"[0-9]\"):\n",
    "    flist = []\n",
    "    for i in range(1, n+1):\n",
    "        fn_search = os.path.join(indir, \"{0}{1}{2}\".format(prefix, i*patt, suffix))\n",
    "        flist = flist + glob.glob(fn_search)\n",
    "    \n",
    "    if len(flist) == 0:\n",
    "        raise FileNotFoundError(\"Could not find any file under '{0}' with prefix '{1}' and suffix '{2}' containing digits.\".format(indir, prefix, suffix))\n",
    "    return flist\n",
    "\n",
    "# get list of files with sample index between 0 and 999.\n",
    "vfp_list = get_fname_ndigits(indir, \"vfp_date_*sample_ind_\", \".nc\", 3)\n",
    "outfile = os.path.join(indir, \"vfp_{0}_forecasts_sample_ind_0_999.nc\".format(model))\n",
    "\n",
    "if not os.path.isfile(outfile):\n",
    "    print(\"File '{0}' does not exist. \\n Start reading data with sample index between 0 and 999 from '{1}'...\".format(outfile, indir))\n",
    "    data_all = xr.open_mfdataset(vfp_list, concat_dim=\"init_time\", combine=\"nested\", decode_cf=True).load()\n",
    "    data_all = data_all.sortby(\"init_time\")\n",
    "    print(\"Data loaded successfully. Save merged data to '{0}'.\".format(outfile))\n",
    "    data_all.to_netcdf(outfile, encoding={'init_time':{'units': \"seconds since 1900-01-01 00:00:00\"}})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1c0222c0-386d-44f4-9532-4e824b14828c",
   "metadata": {},
   "source": [
    "Then, we proceed with the rest. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54f4aa3e-3a39-496e-ae97-65f79d9cd598",
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in np.arange(1, 9):\n",
    "    outfile = os.path.join(indir, \"vfp_{0}_forecasts_sample_ind_{1:d}000_{1:d}999.nc\".format(model, i))\n",
    "    if not os.path.isfile(outfile):\n",
    "        print(\"File '{0}' does not exist. Start reading data with sample index between {1:d}000 and {1:d}999 from '{2}'...\".format(outfile, i, indir))\n",
    "        data_all = xr.open_mfdataset(os.path.join(indir, \"vfp_date_*sample_ind_{0}???.nc\".format(i)), concat_dim=\"init_time\", combine=\"nested\", decode_cf=True).load()\n",
    "        data_all = data_all.sortby(\"init_time\")\n",
    "        print(\"Data loaded successfully. Save merged data to '{0}'.\".format(outfile))\n",
    "        data_all.to_netcdf(outfile, encoding={'init_time':{'units': \"seconds since 1900-01-01 00:00:00\"}})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bdf16158-0ce5-40a3-848d-f574a1b9d622",
   "metadata": {},
   "source": [
    "Still, xarray's `open_mfdataset`-method would not be able to concatenate all data since the `init_time`-dimension is not montonically increasing/decreasing when looping through the files. <br>\n",
    "Thus, we have to merge the data manually.\n",
    "The merged dataset is then saved to separate datafile for later computation.\n",
    "\n",
    "If the data has already been merged, we simply read the data from the corresponding netCDF-file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92f15edf-c23f-4803-b3c5-618305194de5",
   "metadata": {},
   "outputs": [],
   "source": [
    "outfile_all = os.path.join(indir, \"vfp_{0}_forecasts_all.nc\".format(model))\n",
    "\n",
    "if not os.path.isfile(outfile_all):\n",
    "    \n",
    "    print(\"netCDF-file with all forecasts '{0}' does not exist yet. Start merging and sorting all precursor files.\".format(outfile))\n",
    "    all_files = sorted(glob.glob(os.path.join(indir, \"vfp_{0}_forecasts_sample_ind_*.nc\".format(model))))\n",
    "    \n",
    "    if len(all_files) == 0:\n",
    "        raise FileNotFoundError(\"Could not find any precursor files.\")\n",
    "\n",
    "    for i, f in enumerate(all_files):\n",
    "        print(\"Processing file '{0}'\".format(f))\n",
    "        tmp = xr.open_dataset(os.path.join(indir, f)).load()\n",
    "        if i == 0:\n",
    "            all_fcst = tmp.copy()\n",
    "        else:\n",
    "            print(\"Start merging\")\n",
    "            all_fcst = xr.merge([all_fcst, tmp])   \n",
    "\n",
    "    # sort by init_time-dimension...\n",
    "    all_fcst = all_fcst.sortby(\"init_time\")\n",
    "    # ... and save to file\n",
    "    print(\"Finally, write all merged and sorted data to '{0}'.\".format(outfile_all))\n",
    "    all_fcst.to_netcdf(outfile_all)\n",
    "else:\n",
    "    all_fcst = xr.open_dataset(outfile_all).load()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0fcf1cb1-ba0d-4262-8e23-12ba44b6e2d0",
   "metadata": {},
   "source": [
    "Now, we slice the dataset to the domain of interest (defined by `lonlatbox`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ede23e56-5be8-48be-b584-0eb8741acbf3",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_fcst_sl = all_fcst.sel({\"lon\": slice(lonlatbox[0], lonlatbox[1]), \"lat\": slice(lonlatbox[3], lonlatbox[2])}) \n",
    "print(all_fcst_sl)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e21b89c8-57ab-4070-9b4c-ec0fe24c37b9",
   "metadata": {},
   "source": [
    "Next we initialize the function for calculating the MSE and call it to evaluate the ERA5 and persistence forecasts. <br>\n",
    "If you require further evaluation metrics, just expand the cell accordingly, e.g. add the following lines <br>\n",
    "```\n",
    "ssim_func = Scores(\"ssim\", [\"lat\", \"lon\"]).score_func \n",
    "\n",
    "ssim_era5_all = ssim_func(data_fcst=era5_fcst[varname_fcst], data_ref=era5_fcst[varname_ref])\n",
    "ssim_per_all = (data_fcst=era5_fcst[varname_per], data_ref=era5_fcst[varname_ref])\n",
    "```\n",
    "in case you want to evaluate the SSIM as well."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2b70b80-6b86-4674-b051-6a23aaa821ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "mse_func = Scores(\"mse\", [\"lat\", \"lon\"]).score_func\n",
    "varname_ref, varname_fcst, varname_per = \"2t_ref\", \"2t_{0}_fcst\".format(model), \"2t_persistence_fcst\"\n",
    "\n",
    "mse_model_all = mse_func(data_fcst=all_fcst_sl[varname_fcst], data_ref=all_fcst_sl[varname_ref])\n",
    "mse_per_all = mse_func(data_fcst=all_fcst_sl[varname_per], data_ref=all_fcst_sl[varname_ref])"
   ]
  },
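  {
   "cell_type": "markdown",
   "id": "b3d1f2a0-4c5e-4a6b-9c7d-1e2f3a4b5c6d",
   "metadata": {},
   "source": [
    "As a sanity check, the MSE can be recomputed with plain xarray operations. The following sketch assumes that `Scores(\"mse\", [\"lat\", \"lon\"])` simply averages the squared error over the latitude and longitude dimensions; if the utility applies additional weighting (e.g. cosine-latitude weights), the comparison will fail."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4e2a3b1-5d6f-4b7c-8d9e-2f3a4b5c6d7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity-check sketch: recompute the MSE with plain xarray operations.\n",
    "# Assumption: the Scores-utility averages the squared error over lat/lon\n",
    "# without additional weighting.\n",
    "mse_manual = ((all_fcst_sl[varname_fcst] - all_fcst_sl[varname_ref])**2).mean(dim=[\"lat\", \"lon\"])\n",
    "print(np.allclose(mse_manual.values, mse_model_all.values))"
   ]
  },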
  {
   "cell_type": "markdown",
   "id": "7745356d-ad44-47b6-9655-8d6db3433b1a",
   "metadata": {},
   "source": [
    "Then, we initialize the data arrays to store the desired evaluation metrics..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b49db031-126c-44b1-b649-4f70587fac89",
   "metadata": {},
   "outputs": [],
   "source": [
    "fcst_hours = all_fcst_sl[\"fcst_hour\"]\n",
    "nhours = len(fcst_hours)\n",
    "nboots=1000\n",
    "\n",
    "mse_model_fcst = xr.DataArray(np.empty(nhours, dtype=object), coords={\"fcst_hour\": fcst_hours}, dims=[\"fcst_hour\"])\n",
    "mse_model_fcst_boot = xr.DataArray(np.empty((nhours, nboots), dtype=object),\n",
    "                                  coords={\"fcst_hour\": fcst_hours, \"iboot\": np.arange(nboots)},\n",
    "                                  dims=[\"fcst_hour\", \"iboot\"])\n",
    "mse_per_fcst = xr.DataArray(np.empty(nhours, dtype=object), coords={\"fcst_hour\": fcst_hours}, dims=[\"fcst_hour\"])\n",
    "mse_per_fcst_boot = xr.DataArray(np.empty((nhours, nboots), dtype=object),\n",
    "                                 coords={\"fcst_hour\": fcst_hours, \"iboot\": np.arange(nboots)},\n",
    "                                 dims=[\"fcst_hour\", \"iboot\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "55967405-02d1-46e8-b3c3-8952d0e28bd2",
   "metadata": {},
   "source": [
    "... and populate them by looping over all forecast hours."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5090e71c-f20f-43e6-94f6-71cbd0b6006d",
   "metadata": {},
   "outputs": [],
   "source": [
    "for i, fh in enumerate(fcst_hours):\n",
    "    mse_model_curr = mse_model_all.sel(fcst_hour=fh)\n",
    "    mse_per_curr = mse_per_all.sel(fcst_hour=fh)\n",
    "    mse_model_fcst[fh-1], mse_per_fcst[fh-1] = mse_model_curr.mean(), mse_per_curr.mean()\n",
    "\n",
    "    mse_model_fcst_boot[i, :] = perform_block_bootstrap_metric(mse_model_curr, \"init_time\", 24*7)\n",
    "    mse_per_fcst_boot[i, :] = perform_block_bootstrap_metric(mse_per_curr, \"init_time\", 24*7)"
   ]
  },
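  {
   "cell_type": "markdown",
   "id": "d5f3b4c2-6e70-4c8d-9eaf-3a4b5c6d7e8f",
   "metadata": {},
   "source": [
    "For reference: a block bootstrap draws contiguous blocks along `init_time` (here with a block length of one week, i.e. 24*7 hourly samples) with replacement, so that the temporal autocorrelation within each block is preserved. The following minimal NumPy sketch illustrates the idea; it is an assumption about what `perform_block_bootstrap_metric` does internally, not a drop-in replacement."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e604c5d3-7f81-4d9e-8fb0-4b5c6d7e8f90",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal block-bootstrap sketch (an assumption about the internals of\n",
    "# perform_block_bootstrap_metric, not a drop-in replacement):\n",
    "# draw contiguous blocks with replacement and average each replicate.\n",
    "def block_bootstrap_sketch(values, block_length, nboots=1000, seed=42):\n",
    "    rng = np.random.default_rng(seed)\n",
    "    values = np.asarray(values)\n",
    "    nblocks = len(values) // block_length\n",
    "    # reshape into non-overlapping blocks (truncating the remainder)\n",
    "    blocks = values[:nblocks * block_length].reshape(nblocks, block_length)\n",
    "    # sample nblocks block indices per bootstrap replicate\n",
    "    idx = rng.integers(0, nblocks, size=(nboots, nblocks))\n",
    "    return blocks[idx].mean(axis=(1, 2))\n",
    "\n",
    "boot_sketch = block_bootstrap_sketch(mse_model_all.isel(fcst_hour=0).values, 24*7)\n",
    "print(boot_sketch.shape, boot_sketch.mean())"
   ]
  },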
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d526324d-5d19-4193-8208-e609d9c65205",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(mse_model_fcst)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b42cd738-b966-4b24-ad13-351d9b88f9e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(mse_model_fcst)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b13c7287-7a8c-4133-bccc-f250bf25dad7",
   "metadata": {},
   "source": [
    "Finally, we put the data arrays into a joint dataset and save the results into the netCDF-file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7cd455ee-4749-46dd-8095-9e43744a1563",
   "metadata": {},
   "outputs": [],
   "source": [
    "# create Dataset and save to netCDF-file\n",
    "ds_mse = xr.Dataset({\"2t_{0}_mse_avg\".format(model): mse_model_fcst, \"2t_{0}_mse_bootstrapped\".format(model): mse_model_fcst_boot, \n",
    "                     \"2t_persistence_mse_avg\": mse_per_fcst, \"2t_persistence_mse_bootstrapped\": mse_per_fcst_boot})\n",
    "\n",
    "outfile = os.path.join(indir, \"evaluation_metrics_{0:d}x{1:d}.nc\".format(len(all_fcst_sl[\"lon\"]), len(all_fcst_sl[\"lat\"])))\n",
    "\n",
    "print(\"Save evaluation metrics to '{0}'\".format(outfile))\n",
    "print(ds_mse)\n",
    "ds_mse.to_netcdf(outfile)"
   ]
  },
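  {
   "cell_type": "markdown",
   "id": "f715d6e4-8092-4eaf-9c01-5c6d7e8f9012",
   "metadata": {},
   "source": [
    "To work with the saved metrics later on, the file can simply be re-opened. As an illustration, the following usage sketch derives a 95% confidence interval of the model MSE per forecast hour from the bootstrap samples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0826e7f5-91a3-4fb0-8d12-6d7e8f901234",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Usage sketch: re-open the metrics file and derive a 95% bootstrap\n",
    "# confidence interval for the model MSE per forecast hour.\n",
    "ds_check = xr.open_dataset(outfile)\n",
    "mse_boot = ds_check[\"2t_{0}_mse_bootstrapped\".format(model)]\n",
    "ci_low = mse_boot.quantile(0.025, dim=\"iboot\")\n",
    "ci_up = mse_boot.quantile(0.975, dim=\"iboot\")\n",
    "print(ci_low.values)\n",
    "print(ci_up.values)"
   ]
  },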
  {
   "cell_type": "markdown",
   "id": "69bb9464-6fdb-489b-ba59-0170040144ee",
   "metadata": {},
   "source": [
    "## Done!"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "PyDeepLearning-1.1",
   "language": "python",
   "name": "pydeeplearning"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}