diff --git a/BLcourse2.3/01_one_dim.ipynb b/BLcourse2.3/01_one_dim.ipynb
index 826c08fab29a8ac94ae682dccf9749d9dc38c42c..f1c2cae93c5a85f2ae9208cb07b48adbaf166ba7 100644
--- a/BLcourse2.3/01_one_dim.ipynb
+++ b/BLcourse2.3/01_one_dim.ipynb
@@ -2,14 +2,17 @@
  "cells": [
   {
    "cell_type": "markdown",
-   "id": "16bd8a69",
+   "id": "5ece8acf",
    "metadata": {},
    "source": [
     "# Notation\n",
     "$\\newcommand{\\ve}[1]{\\mathit{\\boldsymbol{#1}}}$\n",
     "$\\newcommand{\\ma}[1]{\\mathbf{#1}}$\n",
-    "$\\newcommand{\\pred}[1]{\\widehat{#1}}$\n",
+    "$\\newcommand{\\pred}[1]{\\rm{#1}}$\n",
+    "$\\newcommand{\\predve}[1]{\\mathbf{#1}}$\n",
     "$\\newcommand{\\cov}{\\mathrm{cov}}$\n",
+    "$\\newcommand{\\test}[1]{#1_*}$\n",
+    "$\\newcommand{\\testtest}[1]{#1_{**}}$\n",
     "\n",
     "Vector $\\ve a\\in\\mathbb R^n$ or $\\mathbb R^{n\\times 1}$, so \"column\" vector.\n",
     "Matrix $\\ma A\\in\\mathbb R^{n\\times m}$. Design matrix with input vectors $\\ve\n",
@@ -23,7 +26,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "87093c43",
+   "id": "71f86676",
    "metadata": {},
    "source": [
     "# Imports, helpers, setup"
@@ -32,12 +35,22 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "69a0af7f",
+   "id": "8c581079",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "##%matplotlib notebook\n",
+    "##%matplotlib widget\n",
+    "%matplotlib inline"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "2f89a0c6",
    "metadata": {},
    "outputs": [],
    "source": [
-    "%matplotlib inline\n",
-    "\n",
     "import math\n",
     "from collections import defaultdict\n",
     "from pprint import pprint\n",
@@ -47,43 +60,7 @@
     "from matplotlib import pyplot as plt\n",
     "from matplotlib import is_interactive\n",
     "\n",
-    "\n",
-    "def extract_model_params(model, raw=False) -> dict:\n",
-    "    \"\"\"Helper to convert model.named_parameters() to dict.\n",
-    "\n",
-    "    With raw=True, use\n",
-    "        foo.bar.raw_param\n",
-    "    else\n",
-    "        foo.bar.param\n",
-    "\n",
-    "    See https://docs.gpytorch.ai/en/stable/examples/00_Basic_Usage/Hyperparameters.html#Raw-vs-Actual-Parameters\n",
-    "    \"\"\"\n",
-    "    if raw:\n",
-    "        return dict(\n",
-    "            (p_name, p_val.item())\n",
-    "            for p_name, p_val in model.named_parameters()\n",
-    "        )\n",
-    "    else:\n",
-    "        out = dict()\n",
-    "        # p_name = 'covar_module.base_kernel.raw_lengthscale'. Access\n",
-    "        # model.covar_module.base_kernel.lengthscale (w/o the raw_)\n",
-    "        for p_name, p_val in model.named_parameters():\n",
-    "            # Yes, eval() hack. Sorry.\n",
-    "            p_name = p_name.replace(\".raw_\", \".\")\n",
-    "            p_val = eval(f\"model.{p_name}\")\n",
-    "            out[p_name] = p_val.item()\n",
-    "        return out\n",
-    "\n",
-    "\n",
-    "def plot_samples(ax, X_pred, samples, label=None, **kwds):\n",
-    "    plot_kwds = dict(color=\"tab:green\", alpha=0.3)\n",
-    "    plot_kwds.update(kwds)\n",
-    "\n",
-    "    if label is None:\n",
-    "        ax.plot(X_pred, samples.T, **plot_kwds)\n",
-    "    else:\n",
-    "        ax.plot(X_pred, samples[0, :], **plot_kwds, label=label)\n",
-    "        ax.plot(X_pred, samples[1:, :].T, **plot_kwds, label=\"_\")\n",
+    "from utils import extract_model_params, plot_samples\n",
     "\n",
     "\n",
     "# Default float32 results in slightly noisy prior samples. Less so with\n",
@@ -105,7 +82,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "eba6d895",
+   "id": "97e05998",
    "metadata": {
     "lines_to_next_cell": 2
    },
@@ -122,7 +99,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "24e8765e",
+   "id": "06c7e542",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -150,7 +127,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "9b980b1c",
+   "id": "28539d2a",
    "metadata": {
     "lines_to_next_cell": 2
    },
@@ -161,10 +138,10 @@
     "likelihood. The kernel is the squared exponential kernel with a scaling\n",
     "factor.\n",
     "\n",
-    "$$\\kappa(\\ve x_i, \\ve x_j) = \\sigma_f\\,\\exp\\left(-\\frac{\\lVert\\ve x_i - \\ve x_j\\rVert_2^2}{2\\,\\ell^2}\\right)$$\n",
+    "$$\\kappa(\\ve x_i, \\ve x_j) = s\\,\\exp\\left(-\\frac{\\lVert\\ve x_i - \\ve x_j\\rVert_2^2}{2\\,\\ell^2}\\right)$$\n",
     "\n",
     "This makes two hyper params, namely the length scale $\\ell$ and the scaling\n",
-    "$\\sigma_f$. The latter is implemented by wrapping the `RBFKernel` with\n",
+    "$s$. The latter is implemented by wrapping the `RBFKernel` with\n",
     "`ScaleKernel`.\n",
     "\n",
     "In addition, we define a constant mean via `ConstantMean`. Finally we have\n",
@@ -172,14 +149,14 @@
     "\n",
     "* $\\ell$ = `model.covar_module.base_kernel.lengthscale`\n",
     "* $\\sigma_n^2$ = `model.likelihood.noise_covar.noise`\n",
-    "* $\\sigma_f$ = `model.covar_module.outputscale`\n",
+    "* $s$ = `model.covar_module.outputscale`\n",
     "* $m(\\ve x) = c$ = `model.mean_module.constant`"
    ]
   },
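+  {
+   "cell_type": "markdown",
+   "id": "a1c3f2d0",
+   "metadata": {},
+   "source": [
+    "As a quick, purely illustrative check of the kernel formula above (not\n",
+    "part of the model), we can evaluate $\kappa(\ve x_i, \ve x_j)$ for two 1D\n",
+    "points, assuming made-up values $\ell=2$ and $s=1.5$."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b4e91a77",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Hedged aside: the squared exponential kernel by hand, with assumed\n",
+    "# hyper params ell=2.0 and s=1.5 (illustrative, not the model's values)\n",
+    "ell = 2.0\n",
+    "s = 1.5\n",
+    "xi = torch.tensor([1.0])\n",
+    "xj = torch.tensor([2.5])\n",
+    "kappa = s * torch.exp(-((xi - xj) ** 2).sum() / (2.0 * ell**2))\n",
+    "print(kappa)"
+   ]
+  },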
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "a0a0c12c",
+   "id": "84c29689",
    "metadata": {
     "lines_to_next_cell": 2
    },
@@ -203,6 +180,7 @@
     "        )\n",
     "\n",
     "    def forward(self, x):\n",
+    "        \"\"\"The prior, defined in terms of the mean and covariance function.\"\"\"\n",
     "        mean_x = self.mean_module(x)\n",
     "        covar_x = self.covar_module(x)\n",
     "        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)\n",
@@ -215,7 +193,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "e2778668",
+   "id": "ce3d687b",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -226,7 +204,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "eacc9f2b",
+   "id": "332904ac",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -237,7 +215,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "f90ce51c",
+   "id": "5142e002",
    "metadata": {
     "lines_to_next_cell": 2
    },
@@ -254,22 +232,22 @@
   },
   {
    "cell_type": "markdown",
-   "id": "eb9ed01f",
+   "id": "2c09f48a",
    "metadata": {},
    "source": [
     "# Sample from the GP prior\n",
     "\n",
     "We sample a number of functions $f_j, j=1,\\ldots,M$ from the GP prior and\n",
-    "evaluate them at all $\\ma X$ = `X_pred` points, of which we have $N'=200$. So\n",
-    "we effectively generate samples from $p(\\pred{\\ve y}|\\ma X) = \\mathcal N(\\ve\n",
-    "c, \\ma K)$. Each sampled vector $\\pred{\\ve y}\\in\\mathbb R^{N'}$ and the\n",
-    "covariance (kernel) matrix is $\\ma K\\in\\mathbb R^{N'\\times N'}$."
+    "evaluate them at all $\\ma X$ = `X_pred` points, of which we have $N=200$. So\n",
+    "we effectively generate samples from $p(\\predve f|\\ma X) = \\mathcal N(\\ve\n",
+    "c, \\ma K)$. Each sampled vector $\\predve f\\in\\mathbb R^{N}$ and the\n",
+    "covariance (kernel) matrix is $\\ma K\\in\\mathbb R^{N\\times N}$."
    ]
   },
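+  {
+   "cell_type": "markdown",
+   "id": "c5d2e8f1",
+   "metadata": {},
+   "source": [
+    "Before doing this with gpytorch below, here is a minimal sketch of what\n",
+    "sampling from $\mathcal N(\ve c, \ma K)$ amounts to, assuming a\n",
+    "hand-rolled RBF kernel with made-up hyper params ($\ell=1$, $s=1$, $c=0$)\n",
+    "rather than the model's: factor $\ma K = \ma L\,\ma L^\top$ with a\n",
+    "Cholesky decomposition and transform standard normal samples."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d6a3b9c2",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Prior sampling sketch with assumed hyper params (not the model's):\n",
+    "# f = c + L z with z ~ N(0, I) and L L^T = K, so cov(f) = K\n",
+    "def rbf_kernel(Xa, Xb, ell=1.0, s=1.0):\n",
+    "    return s * torch.exp(-torch.cdist(Xa, Xb) ** 2 / (2.0 * ell**2))\n",
+    "\n",
+    "\n",
+    "# torch.cdist needs 2D inputs; X_pred may be a 1D tensor here\n",
+    "Xp = X_pred[:, None] if X_pred.ndim == 1 else X_pred\n",
+    "K_prior = rbf_kernel(Xp, Xp)\n",
+    "# small diagonal jitter for a numerically stable Cholesky factorization\n",
+    "L_chol = torch.linalg.cholesky(K_prior + 1e-8 * torch.eye(len(Xp)))\n",
+    "f_manual = 0.0 + torch.randn(10, len(Xp)) @ L_chol.T\n",
+    "print(f_manual.shape)"
+   ]
+  },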
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "894b7e36",
+   "id": "9d766f10",
    "metadata": {
     "lines_to_next_cell": 2
    },
@@ -305,11 +283,11 @@
   },
   {
    "cell_type": "markdown",
-   "id": "33bf94f7",
+   "id": "e31330a2",
    "metadata": {},
    "source": [
     "Let's investigate the samples more closely. A constant mean $\\ve m(\\ma X) =\n",
-    "\\ve c$ does *not* mean that each sampled vector $\\pred{\\ve y}$'s mean is\n",
+    "\\ve c$ does *not* mean that each sampled vector $\\predve f$'s mean is\n",
     "equal to $c$. Instead, we have that at each $\\ve x_i$, the mean of\n",
     "*all* sampled functions is the same, so $\\frac{1}{M}\\sum_{j=1}^M f_j(\\ve x_i)\n",
     "\\approx c$ and for $M\\rightarrow\\infty$ it will be exactly $c$.\n"
@@ -318,7 +296,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "ff0dcbaa",
+   "id": "e2372d76",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -332,7 +310,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "7018cb26",
+   "id": "cb1d5d14",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -346,7 +324,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "ea005dc2",
+   "id": "9db57693",
    "metadata": {},
    "source": [
     "# Fit GP to data: optimize hyper params\n",
@@ -366,10 +344,8 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "67b5961b",
-   "metadata": {
-    "lines_to_next_cell": 2
-   },
+   "id": "58266bb8",
+   "metadata": {},
    "outputs": [],
    "source": [
     "# Train mode\n",
@@ -390,8 +366,16 @@
     "        print(f\"iter {ii+1}/{n_iter}, {loss=:.3f}\")\n",
     "    for p_name, p_val in extract_model_params(model).items():\n",
     "        history[p_name].append(p_val)\n",
-    "    history[\"loss\"].append(loss.item())\n",
-    "\n",
+    "    history[\"loss\"].append(loss.item())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "2cc3092c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
     "# Plot hyper params and loss (neg. log marginal likelihood) convergence\n",
     "ncols = len(history)\n",
     "fig, axs = plt.subplots(ncols=ncols, nrows=1, figsize=(ncols * 5, 5))\n",
@@ -401,30 +385,41 @@
     "    ax.set_xlabel(\"iterations\")"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "030d1fc8",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Values of optimized hyper params\n",
+    "pprint(extract_model_params(model, raw=False))"
+   ]
+  },
   {
    "cell_type": "markdown",
-   "id": "5faec1a4",
+   "id": "bb0c1621",
    "metadata": {},
    "source": [
     "# Run prediction\n",
     "\n",
     "We show \"noiseless\" (left: $\\sigma = \\sqrt{\\mathrm{diag}(\\ma\\Sigma)}$) vs.\n",
     "\"noisy\" (right: $\\sigma = \\sqrt{\\mathrm{diag}(\\ma\\Sigma + \\sigma_n^2\\,\\ma\n",
-    "I_N)}$) predictions, where $\\ma\\Sigma\\equiv\\cov(\\ve f_*)$ is the posterior\n",
-    "predictive covariance matrix from R&W 2006 eq. 2.24 with $\\ma K = K(X,X)$,\n",
-    "$\\ma K'=K(X_*, X)$ and $\\ma K''=K(X_*, X_*)$, so\n",
+    "I_N)}$) predictions with\n",
     "\n",
-    "$$\\ma\\Sigma = \\ma K'' - \\ma K'\\,(\\ma K+\\sigma_n^2\\,\\ma I)^{-1}\\,\\ma K'^\\top$$\n",
+    "$$\\ma\\Sigma = \\testtest{\\ma K} - \\test{\\ma K}\\,(\\ma K+\\sigma_n^2\\,\\ma I)^{-1}\\,\\test{\\ma K}^\\top$$\n",
     "\n",
-    "See\n",
-    "https://elcorto.github.io/gp_playground/content/gp_pred_comp/notebook_plot.html\n",
-    "for details."
+    "We find that $\\ma\\Sigma$ reflects behavior we would like to see from\n",
+    "epistemic uncertainty -- it is high when we have no data\n",
+    "(out-of-distribution). But this alone isn't the whole story. We need to add\n",
+    "the estimated noise level $\\sigma_n^2$ in order for the confidence band to\n",
+    "cover the data."
    ]
   },
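+  {
+   "cell_type": "markdown",
+   "id": "e7f4c0d3",
+   "metadata": {},
+   "source": [
+    "As a hedged aside before letting gpytorch do the work below, the\n",
+    "equation above can be evaluated directly, again with a hand-rolled RBF\n",
+    "kernel and made-up hyper params ($\ell=1$, $s=1$, $\sigma_n^2=0.1$, not\n",
+    "the trained values)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f8a5d1e4",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Posterior predictive covariance from the equation above, with assumed\n",
+    "# hyper params; rbf_kernel is redefined so this cell is self-contained\n",
+    "def rbf_kernel(Xa, Xb, ell=1.0, s=1.0):\n",
+    "    return s * torch.exp(-torch.cdist(Xa, Xb) ** 2 / (2.0 * ell**2))\n",
+    "\n",
+    "\n",
+    "Xt = X_train[:, None] if X_train.ndim == 1 else X_train\n",
+    "Xs = X_pred[:, None] if X_pred.ndim == 1 else X_pred\n",
+    "sigma_n2 = 0.1\n",
+    "K = rbf_kernel(Xt, Xt)\n",
+    "K_s = rbf_kernel(Xs, Xt)\n",
+    "K_ss = rbf_kernel(Xs, Xs)\n",
+    "# solve() instead of an explicit matrix inverse for numerical stability\n",
+    "Sigma_manual = K_ss - K_s @ torch.linalg.solve(\n",
+    "    K + sigma_n2 * torch.eye(len(Xt)), K_s.T\n",
+    ")\n",
+    "print(Sigma_manual.diagonal().min(), Sigma_manual.diagonal().max())"
+   ]
+  },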
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "e56fd041",
+   "id": "5821ae0f",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -437,7 +432,10 @@
     "    post_pred_y = likelihood(model(X_pred))\n",
     "\n",
     "    fig, axs = plt.subplots(ncols=2, figsize=(12, 5))\n",
-    "    for ii, (ax, post_pred) in enumerate(zip(axs, [post_pred_f, post_pred_y])):\n",
+    "    fig_sigmas, ax_sigmas = plt.subplots()\n",
+    "    for ii, (ax, post_pred, name) in enumerate(\n",
+    "        zip(axs, [post_pred_f, post_pred_y], [\"f\", \"y\"])\n",
+    "    ):\n",
     "        yf_mean = post_pred.mean\n",
     "        yf_samples = post_pred.sample(sample_shape=torch.Size((10,)))\n",
     "\n",
@@ -467,6 +465,23 @@
     "            color=\"tab:orange\",\n",
     "            alpha=0.3,\n",
     "        )\n",
+    "        if name == \"f\":\n",
+    "            sigma_label = r\"$\\pm 2\\sqrt{\\mathrm{diag}(\\Sigma)}$\"\n",
+    "            zorder = 1\n",
+    "        else:\n",
+    "            sigma_label = (\n",
+    "                r\"$\\pm 2\\sqrt{\\mathrm{diag}(\\Sigma + \\sigma_n^2\\,I)}$\"\n",
+    "            )\n",
+    "            zorder = 0\n",
+    "        ax_sigmas.fill_between(\n",
+    "            X_pred.numpy(),\n",
+    "            lower.numpy(),\n",
+    "            upper.numpy(),\n",
+    "            label=\"confidence \" + sigma_label,\n",
+    "            color=\"tab:orange\" if name == \"f\" else \"tab:blue\",\n",
+    "            alpha=0.5,\n",
+    "            zorder=zorder,\n",
+    "        )\n",
     "        y_min = y_train.min()\n",
     "        y_max = y_train.max()\n",
     "        y_span = y_max - y_min\n",
@@ -474,6 +489,7 @@
     "        plot_samples(ax, X_pred, yf_samples, label=\"posterior pred. samples\")\n",
     "        if ii == 1:\n",
     "            ax.legend()\n",
+    "    ax_sigmas.legend()\n",
     "\n",
     "# When running as script\n",
     "if not is_interactive():\n",
diff --git a/BLcourse2.3/02_two_dim.ipynb b/BLcourse2.3/02_two_dim.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..4216214ffb5d21cb15bdfe4935cd99f9b581acb8
--- /dev/null
+++ b/BLcourse2.3/02_two_dim.ipynb
@@ -0,0 +1,595 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7cc3b0fe",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%matplotlib notebook\n",
+    "##%matplotlib widget\n",
+    "##%matplotlib inline"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "194927e4",
+   "metadata": {
+    "lines_to_next_cell": 2
+   },
+   "outputs": [],
+   "source": [
+    "from collections import defaultdict\n",
+    "from pprint import pprint\n",
+    "\n",
+    "import torch\n",
+    "import gpytorch\n",
+    "from matplotlib import pyplot as plt\n",
+    "from matplotlib import is_interactive\n",
+    "from mpl_toolkits.mplot3d import Axes3D\n",
+    "import numpy as np\n",
+    "\n",
+    "from sklearn.preprocessing import StandardScaler\n",
+    "\n",
+    "from utils import extract_model_params, plot_samples, fig_ax_3d"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "44d9d58d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "torch.set_default_dtype(torch.float64)\n",
+    "torch.manual_seed(123)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "4aa6e6bd",
+   "metadata": {
+    "lines_to_next_cell": 2
+   },
+   "source": [
+    "# Generate toy 2D data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3d0d40c6",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class MexicanHat:\n",
+    "    def __init__(self, xlim, ylim, nx, ny, mode, **kwds):\n",
+    "        self.xlim = xlim\n",
+    "        self.ylim = ylim\n",
+    "        self.nx = nx\n",
+    "        self.ny = ny\n",
+    "        self.xg, self.yg = self._get_xy_grid()\n",
+    "        self.XG, self.YG = self._get_meshgrids(self.xg, self.yg)\n",
+    "        self.X = self._make_X(mode)\n",
+    "        self.z = self.func(self.X)\n",
+    "\n",
+    "    def _make_X(self, mode=\"grid\"):\n",
+    "        if mode == \"grid\":\n",
+    "            X = torch.empty((self.nx * self.ny, 2))\n",
+    "            X[:, 0] = self.XG.flatten()\n",
+    "            X[:, 1] = self.YG.flatten()\n",
+    "        elif mode == \"rand\":\n",
+    "            X = torch.rand(self.nx * self.ny, 2)\n",
+    "            X[:, 0] = X[:, 0] * (self.xlim[1] - self.xlim[0]) + self.xlim[0]\n",
+    "            X[:, 1] = X[:, 1] * (self.ylim[1] - self.ylim[0]) + self.ylim[0]\n",
+    "        return X\n",
+    "\n",
+    "    def _get_xy_grid(self):\n",
+    "        x = torch.linspace(self.xlim[0], self.xlim[1], self.nx)\n",
+    "        y = torch.linspace(self.ylim[0], self.ylim[1], self.ny)\n",
+    "        return x, y\n",
+    "\n",
+    "    @staticmethod\n",
+    "    def _get_meshgrids(xg, yg):\n",
+    "        return torch.meshgrid(xg, yg, indexing=\"ij\")\n",
+    "\n",
+    "    @staticmethod\n",
+    "    def func(X):\n",
+    "        r = torch.sqrt((X**2).sum(axis=1))\n",
+    "        return torch.sin(r) / r\n",
+    "\n",
+    "    @staticmethod\n",
+    "    def n2t(x):\n",
+    "        return torch.from_numpy(x)\n",
+    "\n",
+    "    def apply_scalers(self, x_scaler, y_scaler):\n",
+    "        self.X = self.n2t(x_scaler.transform(self.X))\n",
+    "        Xtmp = x_scaler.transform(torch.stack((self.xg, self.yg), dim=1))\n",
+    "        self.XG, self.YG = self._get_meshgrids(\n",
+    "            self.n2t(Xtmp[:, 0]), self.n2t(Xtmp[:, 1])\n",
+    "        )\n",
+    "        self.z = self.n2t(y_scaler.transform(self.z[:, None])[:, 0])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "87f84642",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "data_train = MexicanHat(\n",
+    "    xlim=[-15, 25], ylim=[-15, 5], nx=20, ny=20, mode=\"rand\"\n",
+    ")\n",
+    "x_scaler = StandardScaler().fit(data_train.X)\n",
+    "y_scaler = StandardScaler().fit(data_train.z[:, None])\n",
+    "data_train.apply_scalers(x_scaler, y_scaler)\n",
+    "\n",
+    "data_pred = MexicanHat(\n",
+    "    xlim=[-15, 25], ylim=[-15, 5], nx=100, ny=100, mode=\"grid\"\n",
+    ")\n",
+    "data_pred.apply_scalers(x_scaler, y_scaler)\n",
+    "\n",
+    "# train inputs\n",
+    "X_train = data_train.X\n",
+    "\n",
+    "# inputs for prediction and plotting\n",
+    "X_pred = data_pred.X"
+   ]
+  },
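+  {
+   "cell_type": "markdown",
+   "id": "09b6e2f5",
+   "metadata": {},
+   "source": [
+    "A quick sanity check (illustrative only): `StandardScaler` applies\n",
+    "$z = (x - \mu) / \sigma$ per column, so the scaled train inputs and\n",
+    "targets should have roughly zero mean and unit standard deviation."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1ac7f306",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Scaled train data should be close to zero mean and unit std per column\n",
+    "print(X_train.mean(dim=0), X_train.std(dim=0))\n",
+    "print(data_train.z.mean(), data_train.z.std())"
+   ]
+  },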
+  {
+   "cell_type": "markdown",
+   "id": "158c3da2",
+   "metadata": {},
+   "source": [
+    "# Exercise 1"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "eb9e4ee1",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "use_noise = False\n",
+    "use_gap = False"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "197f7480",
+   "metadata": {},
+   "source": [
+    "# Exercise 2"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3f1ba561",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "##use_noise = True\n",
+    "##use_gap = False"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "56b080da",
+   "metadata": {},
+   "source": [
+    "# Exercise 3"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "77e9779e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "##use_noise = False\n",
+    "##use_gap = True"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d3f0b5d7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "if use_noise:\n",
+    "    # noisy train data\n",
+    "    noise_std = 0.2\n",
+    "    noise_dist = torch.distributions.Normal(loc=0, scale=noise_std)\n",
+    "    y_train = data_train.z + noise_dist.sample_n(len(data_train.z))\n",
+    "else:\n",
+    "    # noise-free train data\n",
+    "    noise_std = 0\n",
+    "    y_train = data_train.z"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "6f6b3548",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Cut out part of the train data to create out-of-distribution predictions\n",
+    "\n",
+    "if use_gap:\n",
+    "    mask = (X_train[:, 0] > 0) & (X_train[:, 1] < 0)\n",
+    "    X_train = X_train[~mask, :]\n",
+    "    y_train = y_train[~mask]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "30dbecca",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "fig, ax = fig_ax_3d()\n",
+    "s0 = ax.plot_surface(\n",
+    "    data_pred.XG,\n",
+    "    data_pred.YG,\n",
+    "    data_pred.z.reshape((data_pred.nx, data_pred.ny)),\n",
+    "    color=\"tab:grey\",\n",
+    "    alpha=0.5,\n",
+    ")\n",
+    "s1 = ax.scatter(\n",
+    "    xs=X_train[:, 0],\n",
+    "    ys=X_train[:, 1],\n",
+    "    zs=y_train,\n",
+    "    color=\"tab:blue\",\n",
+    "    alpha=0.5,\n",
+    ")\n",
+    "ax.set_xlabel(\"X_0\")\n",
+    "ax.set_ylabel(\"X_1\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "56d27ef1",
+   "metadata": {
+    "lines_to_next_cell": 2
+   },
+   "source": [
+    "# Define GP model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "6da6f37a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class ExactGPModel(gpytorch.models.ExactGP):\n",
+    "    \"\"\"API:\n",
+    "\n",
+    "    model.forward()             prior                   f_pred\n",
+    "    model()                     posterior               f_pred\n",
+    "\n",
+    "    likelihood(model.forward()) prior with noise        y_pred\n",
+    "    likelihood(model())         posterior with noise    y_pred\n",
+    "    \"\"\"\n",
+    "\n",
+    "    def __init__(self, X_train, y_train, likelihood):\n",
+    "        super().__init__(X_train, y_train, likelihood)\n",
+    "        self.mean_module = gpytorch.means.ConstantMean()\n",
+    "        self.covar_module = gpytorch.kernels.ScaleKernel(\n",
+    "            gpytorch.kernels.RBFKernel()\n",
+    "        )\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        \"\"\"The prior, defined in terms of the mean and covariance function.\"\"\"\n",
+    "        mean_x = self.mean_module(x)\n",
+    "        covar_x = self.covar_module(x)\n",
+    "        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)\n",
+    "\n",
+    "\n",
+    "likelihood = gpytorch.likelihoods.GaussianLikelihood()\n",
+    "model = ExactGPModel(X_train, y_train, likelihood)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "50695462",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Inspect the model\n",
+    "print(model)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "393a458d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Default start hyper params\n",
+    "pprint(extract_model_params(model, raw=False))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "95ce5bf2",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Set new start hyper params\n",
+    "model.mean_module.constant = 0.0\n",
+    "model.covar_module.base_kernel.lengthscale = 3.0\n",
+    "model.covar_module.outputscale = 8.0\n",
+    "model.likelihood.noise_covar.noise = 0.1\n",
+    "\n",
+    "pprint(extract_model_params(model, raw=False))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "56ccbafa",
+   "metadata": {},
+   "source": [
+    "# Fit GP to data: optimize hyper params"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d44f3667",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Train mode\n",
+    "model.train()\n",
+    "likelihood.train()\n",
+    "\n",
+    "optimizer = torch.optim.Adam(model.parameters(), lr=0.2)\n",
+    "loss_func = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)\n",
+    "\n",
+    "n_iter = 300\n",
+    "history = defaultdict(list)\n",
+    "for ii in range(n_iter):\n",
+    "    optimizer.zero_grad()\n",
+    "    loss = -loss_func(model(X_train), y_train)\n",
+    "    loss.backward()\n",
+    "    optimizer.step()\n",
+    "    if (ii + 1) % 10 == 0:\n",
+    "        print(f\"iter {ii+1}/{n_iter}, {loss=:.3f}\")\n",
+    "    for p_name, p_val in extract_model_params(model).items():\n",
+    "        history[p_name].append(p_val)\n",
+    "    history[\"loss\"].append(loss.item())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d1ff9c36",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ncols = len(history)\n",
+    "fig, axs = plt.subplots(ncols=ncols, nrows=1, figsize=(ncols * 5, 5))\n",
+    "for ax, (p_name, p_lst) in zip(axs, history.items()):\n",
+    "    ax.plot(p_lst)\n",
+    "    ax.set_title(p_name)\n",
+    "    ax.set_xlabel(\"iterations\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c6c9c1b2",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Values of optimized hyper params\n",
+    "pprint(extract_model_params(model, raw=False))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "b01dac44",
+   "metadata": {},
+   "source": [
+    "# Run prediction"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5bb6926d",
+   "metadata": {
+    "lines_to_next_cell": 2
+   },
+   "outputs": [],
+   "source": [
+    "model.eval()\n",
+    "likelihood.eval()\n",
+    "\n",
+    "with torch.no_grad():\n",
+    "    post_pred_f = model(X_pred)\n",
+    "    post_pred_y = likelihood(model(X_pred))\n",
+    "\n",
+    "    fig = plt.figure()\n",
+    "    ax = fig.add_subplot(projection=\"3d\")\n",
+    "    ax.plot_surface(\n",
+    "        data_pred.XG,\n",
+    "        data_pred.YG,\n",
+    "        data_pred.z.reshape((data_pred.nx, data_pred.ny)),\n",
+    "        color=\"tab:grey\",\n",
+    "        alpha=0.5,\n",
+    "    )\n",
+    "    ax.plot_surface(\n",
+    "        data_pred.XG,\n",
+    "        data_pred.YG,\n",
+    "        post_pred_y.mean.reshape((data_pred.nx, data_pred.ny)),\n",
+    "        color=\"tab:red\",\n",
+    "        alpha=0.5,\n",
+    "    )\n",
+    "    ax.set_xlabel(\"X_0\")\n",
+    "    ax.set_ylabel(\"X_1\")\n",
+    "\n",
+    "assert (post_pred_f.mean == post_pred_y.mean).all()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "9a354f60",
+   "metadata": {},
+   "source": [
+    "# Plot difference to ground truth and uncertainty"
+   ]
+  },
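+  {
+   "cell_type": "markdown",
+   "id": "2bd80417",
+   "metadata": {},
+   "source": [
+    "As a single-number summary to go with the contour plots below: the RMSE\n",
+    "of the posterior predictive mean against the (scaled) ground truth on\n",
+    "the dense prediction grid."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3ce91528",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# RMSE of the predictive mean vs. scaled ground truth on the pred grid\n",
+    "rmse = torch.sqrt(((post_pred_y.mean - data_pred.z) ** 2).mean())\n",
+    "print(f\"{rmse=}\")"
+   ]
+  },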
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f3570739",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ncols = 3\n",
+    "fig, axs = plt.subplots(ncols=ncols, nrows=1, figsize=(ncols * 7, 5))\n",
+    "\n",
+    "vmax = post_pred_y.stddev.max()\n",
+    "cs = []\n",
+    "\n",
+    "cs.append(\n",
+    "    axs[0].contourf(\n",
+    "        data_pred.XG,\n",
+    "        data_pred.YG,\n",
+    "        torch.abs(post_pred_y.mean - data_pred.z).reshape(\n",
+    "            (data_pred.nx, data_pred.ny)\n",
+    "        ),\n",
+    "    )\n",
+    ")\n",
+    "axs[0].set_title(\"|y_pred - y_true|\")\n",
+    "\n",
+    "cs.append(\n",
+    "    axs[1].contourf(\n",
+    "        data_pred.XG,\n",
+    "        data_pred.YG,\n",
+    "        post_pred_f.stddev.reshape((data_pred.nx, data_pred.ny)),\n",
+    "        vmin=0,\n",
+    "        vmax=vmax,\n",
+    "    )\n",
+    ")\n",
+    "axs[1].set_title(\"f_std (epistemic)\")\n",
+    "\n",
+    "cs.append(\n",
+    "    axs[2].contourf(\n",
+    "        data_pred.XG,\n",
+    "        data_pred.YG,\n",
+    "        post_pred_y.stddev.reshape((data_pred.nx, data_pred.ny)),\n",
+    "        vmin=0,\n",
+    "        vmax=vmax,\n",
+    "    )\n",
+    ")\n",
+    "axs[2].set_title(\"y_std (epistemic + aleatoric)\")\n",
+    "\n",
+    "for ax, c in zip(axs, cs):\n",
+    "    ax.set_xlabel(\"X_0\")\n",
+    "    ax.set_ylabel(\"X_1\")\n",
+    "    ax.scatter(x=X_train[:, 0], y=X_train[:, 1], color=\"white\", alpha=0.2)\n",
+    "    fig.colorbar(c, ax=ax)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "9bfc86ee",
+   "metadata": {},
+   "source": [
+    "# Check learned noise"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "6ad91b3b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "print((post_pred_y.stddev**2 - post_pred_f.stddev**2).mean().sqrt())\n",
+    "print(\n",
+    "    np.sqrt(\n",
+    "        extract_model_params(model, raw=False)[\"likelihood.noise_covar.noise\"]\n",
+    "    )\n",
+    ")\n",
+    "print(noise_std)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "18e5a5ee",
+   "metadata": {},
+   "source": [
+    "# Plot confidence bands"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "46c9aab5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "y_mean = post_pred_y.mean.reshape((data_pred.nx, data_pred.ny))\n",
+    "y_std = post_pred_y.stddev.reshape((data_pred.nx, data_pred.ny))\n",
+    "upper = y_mean + 2 * y_std\n",
+    "lower = y_mean - 2 * y_std\n",
+    "\n",
+    "fig, ax = fig_ax_3d()\n",
+    "for Z, color in [(upper, \"tab:green\"), (lower, \"tab:red\")]:\n",
+    "    ax.plot_surface(\n",
+    "        data_pred.XG,\n",
+    "        data_pred.YG,\n",
+    "        Z,\n",
+    "        color=color,\n",
+    "        alpha=0.5,\n",
+    "    )\n",
+    "\n",
+    "contour_z = lower.min() - 1\n",
+    "zlim = ax.get_xlim()\n",
+    "ax.set_zlim((contour_z, zlim[1] + abs(contour_z)))\n",
+    "ax.contourf(data_pred.XG, data_pred.YG, y_std, zdir=\"z\", offset=contour_z)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "81788c9d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# When running as script\n",
+    "if not is_interactive():\n",
+    "    plt.show()"
+   ]
+  }
+ ],
+ "metadata": {
+  "jupytext": {
+   "formats": "ipynb,py:light"
+  },
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}