{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Get Dataset from request"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datetime import datetime as dt\n",
    "from pathlib import Path\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "from toargridding.grids import RegularGrid\n",
    "from toargridding.toar_rest_client import (\n",
    "    AnalysisServiceDownload,\n",
    "    STATION_LAT,\n",
    "    STATION_LON,\n",
    ")\n",
    "from toargridding.metadata import Metadata, TimeSample, AnalysisRequestResult, Coordinates\n",
    "from toargridding.variables import Coordinate\n",
    "\n",
    "\n",
    "endpoint = \"https://toar-data.fz-juelich.de/api/v2/analysis/statistics/\"\n",
    "toargridding_base_path = Path(\"/home/simon/Projects/toar/toargridding/\")\n",
    "cache_dir = toargridding_base_path / \"tests\" / \"results\"\n",
    "data_download_dir = toargridding_base_path / \"tests\" / \"data\"\n",
    "\n",
    "analysis_service = AnalysisServiceDownload(endpoint, cache_dir, data_download_dir)\n",
    "my_grid = RegularGrid(1.9, 2.5)\n",
    "\n",
    "time = TimeSample(dt(2016,1,1), dt(2016,12,31), \"daily\")\n",
    "metadata = Metadata.construct(\"mole_fraction_of_ozone_in_air\", \"mean\", time)\n",
    "\n",
    "with open(\"data/daily_2010-01-01_2011-01-01.zip\", \"r+b\") as sample_file:\n",
    "    response_content = sample_file.read()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = analysis_service.get_data(metadata)\n",
    "ds = my_grid.as_xarray(data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visual inspection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cartopy.crs as ccrs\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.ticker as mticker\n",
    "\n",
    "mean_data = ds[\"mean\"]\n",
    "clean_coords = analysis_service.get_clean_coords(timeseries_metadata)\n",
    "all_na = timeseries.isna().all(axis=1)\n",
    "clean_coords = all_na.to_frame().join(clean_coords)[[\"latitude\", \"longitude\"]]\n",
    "all_na_coords = clean_coords[all_na]\n",
    "not_na_coords = clean_coords[~all_na]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib as mpl\n",
    "\n",
    "\n",
    "def plot_cells(data, stations, na_stations, discrete=True, plot_stations=False):\n",
    "    fig = plt.figure(figsize=(9, 18))\n",
    "\n",
    "    ax = plt.axes(projection=ccrs.PlateCarree())\n",
    "    ax.coastlines()\n",
    "    gl = ax.gridlines(draw_labels=True)\n",
    "    gl.top_labels = False\n",
    "    gl.left_labels = False\n",
    "    gl.xlocator = mticker.FixedLocator(data.longitude.values)\n",
    "    gl.ylocator = mticker.FixedLocator(data.latitude.values)\n",
    "\n",
    "    cmap = mpl.cm.viridis\n",
    "\n",
    "    if discrete:\n",
    "        print(np.unique(data.values))\n",
    "        bounds = np.arange(8)\n",
    "        norm = mpl.colors.BoundaryNorm(bounds, cmap.N, extend=\"both\")\n",
    "        ticks = np.arange(bounds.size + 1)[:-1] + 0.5\n",
    "        ticklables = bounds\n",
    "        \n",
    "        im = plt.pcolormesh(\n",
    "            data.longitude,\n",
    "            data.latitude,\n",
    "            data,\n",
    "            transform=ccrs.PlateCarree(),\n",
    "            cmap=cmap,\n",
    "            shading=\"nearest\",\n",
    "            norm=norm,\n",
    "        )\n",
    "        cb = fig.colorbar(im, ax=ax, shrink=0.2, aspect=25)\n",
    "        cb.set_ticks(ticks)\n",
    "        cb.set_ticklabels(ticklables)\n",
    "        im = plt.pcolormesh(\n",
    "            data.longitude,\n",
    "            data.latitude,\n",
    "            data,\n",
    "            transform=ccrs.PlateCarree(),\n",
    "            cmap=cmap,\n",
    "            shading=\"nearest\",\n",
    "            norm=norm,\n",
    "        )\n",
    "    else:\n",
    "        im = plt.pcolormesh(\n",
    "            data.longitude,\n",
    "            data.latitude,\n",
    "            data,\n",
    "            transform=ccrs.PlateCarree(),\n",
    "            cmap=cmap,\n",
    "            shading=\"nearest\",\n",
    "        )\n",
    "\n",
    "        cb = fig.colorbar(im, ax=ax, shrink=0.2, aspect=25)\n",
    "    \n",
    "\n",
    "    if plot_stations:\n",
    "        plt.scatter(na_stations[STATION_LON], na_stations[STATION_LAT], s=1, c=\"k\")\n",
    "        plt.scatter(stations[STATION_LON], stations[STATION_LAT], s=1, c=\"r\")\n",
    "\n",
    "    plt.tight_layout()\n",
    "\n",
    "    plt.title(f\"global ozon at {data.time.values}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "timestep = 2\n",
    "time = ds.time[timestep]\n",
    "data = ds.sel(time=time)\n",
    "\n",
    "plot_cells(data[\"mean\"], not_na_coords, all_na_coords, discrete=False)\n",
    "plt.show()\n",
    "\n",
    "plot_cells(data[\"n\"], not_na_coords, all_na_coords, discrete=True)\n",
    "plt.show()\n",
    "\n",
    "n_observations = ds[\"n\"].sum([\"latitude\", \"longitude\"])\n",
    "plt.plot(ds.time, n_observations)\n",
    "print(np.unique(ds[\"n\"]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "toargridding-g-KQ1Hyq-py3.10",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}