diff --git a/BLcourse2.3/01_one_dim.ipynb b/BLcourse2.3/01_one_dim.ipynb index 826c08fab29a8ac94ae682dccf9749d9dc38c42c..f1c2cae93c5a85f2ae9208cb07b48adbaf166ba7 100644 --- a/BLcourse2.3/01_one_dim.ipynb +++ b/BLcourse2.3/01_one_dim.ipynb @@ -2,14 +2,17 @@ "cells": [ { "cell_type": "markdown", - "id": "16bd8a69", + "id": "5ece8acf", "metadata": {}, "source": [ "# Notation\n", "$\\newcommand{\\ve}[1]{\\mathit{\\boldsymbol{#1}}}$\n", "$\\newcommand{\\ma}[1]{\\mathbf{#1}}$\n", - "$\\newcommand{\\pred}[1]{\\widehat{#1}}$\n", + "$\\newcommand{\\pred}[1]{\\rm{#1}}$\n", + "$\\newcommand{\\predve}[1]{\\mathbf{#1}}$\n", "$\\newcommand{\\cov}{\\mathrm{cov}}$\n", + "$\\newcommand{\\test}[1]{#1_*}$\n", + "$\\newcommand{\\testtest}[1]{#1_{**}}$\n", "\n", "Vector $\\ve a\\in\\mathbb R^n$ or $\\mathbb R^{n\\times 1}$, so \"column\" vector.\n", "Matrix $\\ma A\\in\\mathbb R^{n\\times m}$. Design matrix with input vectors $\\ve\n", @@ -23,7 +26,7 @@ }, { "cell_type": "markdown", - "id": "87093c43", + "id": "71f86676", "metadata": {}, "source": [ "# Imports, helpers, setup" @@ -32,12 +35,22 @@ { "cell_type": "code", "execution_count": null, - "id": "69a0af7f", + "id": "8c581079", + "metadata": {}, + "outputs": [], + "source": [ + "##%matplotlib notebook\n", + "##%matplotlib widget\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2f89a0c6", "metadata": {}, "outputs": [], "source": [ - "%matplotlib inline\n", - "\n", "import math\n", "from collections import defaultdict\n", "from pprint import pprint\n", @@ -47,43 +60,7 @@ "from matplotlib import pyplot as plt\n", "from matplotlib import is_interactive\n", "\n", - "\n", - "def extract_model_params(model, raw=False) -> dict:\n", - " \"\"\"Helper to convert model.named_parameters() to dict.\n", - "\n", - " With raw=True, use\n", - " foo.bar.raw_param\n", - " else\n", - " foo.bar.param\n", - "\n", - " See https://docs.gpytorch.ai/en/stable/examples/00_Basic_Usage/Hyperparameters.html#Raw-vs-Actual-Parameters\n", - " \"\"\"\n", - " if raw:\n", - " return dict(\n", - " (p_name, p_val.item())\n", - " for p_name, p_val in model.named_parameters()\n", - " )\n", - " else:\n", - " out = dict()\n", - " # p_name = 'covar_module.base_kernel.raw_lengthscale'. Access\n", - " # model.covar_module.base_kernel.lengthscale (w/o the raw_)\n", - " for p_name, p_val in model.named_parameters():\n", - " # Yes, eval() hack. Sorry.\n", - " p_name = p_name.replace(\".raw_\", \".\")\n", - " p_val = eval(f\"model.{p_name}\")\n", - " out[p_name] = p_val.item()\n", - " return out\n", - "\n", - "\n", - "def plot_samples(ax, X_pred, samples, label=None, **kwds):\n", - " plot_kwds = dict(color=\"tab:green\", alpha=0.3)\n", - " plot_kwds.update(kwds)\n", - "\n", - " if label is None:\n", - " ax.plot(X_pred, samples.T, **plot_kwds)\n", - " else:\n", - " ax.plot(X_pred, samples[0, :], **plot_kwds, label=label)\n", - " ax.plot(X_pred, samples[1:, :].T, **plot_kwds, label=\"_\")\n", + "from utils import extract_model_params, plot_samples\n", "\n", "\n", "# Default float32 results in slightly noisy prior samples. 
Less so with\n", @@ -105,7 +82,7 @@ }, { "cell_type": "markdown", - "id": "eba6d895", + "id": "97e05998", "metadata": { "lines_to_next_cell": 2 }, @@ -122,7 +99,7 @@ { "cell_type": "code", "execution_count": null, - "id": "24e8765e", + "id": "06c7e542", "metadata": {}, "outputs": [], "source": [ @@ -150,7 +127,7 @@ }, { "cell_type": "markdown", - "id": "9b980b1c", + "id": "28539d2a", "metadata": { "lines_to_next_cell": 2 }, @@ -161,10 +138,10 @@ "likelihood. The kernel is the squared exponential kernel with a scaling\n", "factor.\n", "\n", - "$$\\kappa(\\ve x_i, \\ve x_j) = \\sigma_f\\,\\exp\\left(-\\frac{\\lVert\\ve x_i - \\ve x_j\\rVert_2^2}{2\\,\\ell^2}\\right)$$\n", + "$$\\kappa(\\ve x_i, \\ve x_j) = s\\,\\exp\\left(-\\frac{\\lVert\\ve x_i - \\ve x_j\\rVert_2^2}{2\\,\\ell^2}\\right)$$\n", "\n", "This makes two hyper params, namely the length scale $\\ell$ and the scaling\n", - "$\\sigma_f$. The latter is implemented by wrapping the `RBFKernel` with\n", + "$s$. The latter is implemented by wrapping the `RBFKernel` with\n", "`ScaleKernel`.\n", "\n", "In addition, we define a constant mean via `ConstantMean`. Finally we have\n", @@ -172,14 +149,14 @@ "\n", "* $\\ell$ = `model.covar_module.base_kernel.lengthscale`\n", "* $\\sigma_n^2$ = `model.likelihood.noise_covar.noise`\n", - "* $\\sigma_f$ = `model.covar_module.outputscale`\n", + "* $s$ = `model.covar_module.outputscale`\n", "* $m(\\ve x) = c$ = `model.mean_module.constant`" ] }, { "cell_type": "code", "execution_count": null, - "id": "a0a0c12c", + "id": "84c29689", "metadata": { "lines_to_next_cell": 2 }, @@ -203,6 +180,7 @@ " )\n", "\n", " def forward(self, x):\n", + " \"\"\"The prior, defined in terms of the mean and covariance function.\"\"\"\n", " mean_x = self.mean_module(x)\n", " covar_x = self.covar_module(x)\n", " return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)\n", @@ -215,7 +193,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e2778668", + "id": "ce3d687b", "metadata": {}, "outputs": [], "source": [ @@ -226,7 +204,7 @@ { "cell_type": "code", "execution_count": null, - "id": "eacc9f2b", + "id": "332904ac", "metadata": {}, "outputs": [], "source": [ @@ -237,7 +215,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f90ce51c", + "id": "5142e002", "metadata": { "lines_to_next_cell": 2 }, @@ -254,22 +232,22 @@ }, { "cell_type": "markdown", - "id": "eb9ed01f", + "id": "2c09f48a", "metadata": {}, "source": [ "# Sample from the GP prior\n", "\n", "We sample a number of functions $f_j, j=1,\\ldots,M$ from the GP prior and\n", - "evaluate them at all $\\ma X$ = `X_pred` points, of which we have $N'=200$. So\n", - "we effectively generate samples from $p(\\pred{\\ve y}|\\ma X) = \\mathcal N(\\ve\n", - "c, \\ma K)$. Each sampled vector $\\pred{\\ve y}\\in\\mathbb R^{N'}$ and the\n", - "covariance (kernel) matrix is $\\ma K\\in\\mathbb R^{N'\\times N'}$." + "evaluate them at all $\\ma X$ = `X_pred` points, of which we have $N=200$. So\n", + "we effectively generate samples from $p(\\predve f|\\ma X) = \\mathcal N(\\ve\n", + "c, \\ma K)$. Each sampled vector $\\predve f\\in\\mathbb R^{N}$ and the\n", + "covariance (kernel) matrix is $\\ma K\\in\\mathbb R^{N\\times N}$." ] }, { "cell_type": "code", "execution_count": null, - "id": "894b7e36", + "id": "9d766f10", "metadata": { "lines_to_next_cell": 2 }, @@ -305,11 +283,11 @@ }, { "cell_type": "markdown", - "id": "33bf94f7", + "id": "e31330a2", "metadata": {}, "source": [ "Let's investigate the samples more closely. 
A constant mean $\\ve m(\\ma X) =\n", - "\\ve c$ does *not* mean that each sampled vector $\\pred{\\ve y}$'s mean is\n", + "\\ve c$ does *not* mean that each sampled vector $\\predve f$'s mean is\n", "equal to $c$. Instead, we have that at each $\\ve x_i$, the mean of\n", "*all* sampled functions is the same, so $\\frac{1}{M}\\sum_{j=1}^M f_j(\\ve x_i)\n", "\\approx c$ and for $M\\rightarrow\\infty$ it will be exactly $c$.\n" ] }, { "cell_type": "code", "execution_count": null, - "id": "ff0dcbaa", + "id": "e2372d76", "metadata": {}, "outputs": [], "source": [ @@ -332,7 +310,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7018cb26", + "id": "cb1d5d14", "metadata": {}, "outputs": [], "source": [ @@ -346,7 +324,7 @@ }, { "cell_type": "markdown", - "id": "ea005dc2", + "id": "9db57693", "metadata": {}, "source": [ "# Fit GP to data: optimize hyper params\n", @@ -366,10 +344,8 @@ { "cell_type": "code", "execution_count": null, - "id": "67b5961b", - "metadata": { - "lines_to_next_cell": 2 - }, + "id": "58266bb8", + "metadata": {}, "outputs": [], "source": [ "# Train mode\n", @@ -390,8 +366,16 @@ " print(f\"iter {ii+1}/{n_iter}, {loss=:.3f}\")\n", " for p_name, p_val in extract_model_params(model).items():\n", " history[p_name].append(p_val)\n", - " history[\"loss\"].append(loss.item())\n", - "\n", + " history[\"loss\"].append(loss.item())" ] }, { "cell_type": "code", "execution_count": null, "id": "2cc3092c", "metadata": {}, "outputs": [], "source": [ "# Plot hyper params and loss (neg. log marginal likelihood) convergence\n", "ncols = len(history)\n", "fig, axs = plt.subplots(ncols=ncols, nrows=1, figsize=(ncols * 5, 5))\n", @@ -401,30 +385,41 @@ " ax.set_xlabel(\"iterations\")" ] }, { "cell_type": "code", "execution_count": null, "id": "030d1fc8", "metadata": {}, "outputs": [], "source": [ "# Values of optimized hyper params\n", "pprint(extract_model_params(model, raw=False))" ] }, { "cell_type": "markdown", - "id": "5faec1a4", + "id": "bb0c1621", "metadata": {}, "source": [ "# Run prediction\n", "\n", "We show \"noiseless\" (left: $\\sigma = \\sqrt{\\mathrm{diag}(\\ma\\Sigma)}$) vs.\n", "\"noisy\" (right: $\\sigma = \\sqrt{\\mathrm{diag}(\\ma\\Sigma + \\sigma_n^2\\,\\ma\n", - "I_N)}$) predictions, where $\\ma\\Sigma\\equiv\\cov(\\ve f_*)$ is the posterior\n", - "predictive covariance matrix from R&W 2006 eq. 2.24 with $\\ma K = K(X,X)$,\n", - "$\\ma K'=K(X_*, X)$ and $\\ma K''=K(X_*, X_*)$, so\n", + "I_N)}$) predictions with\n", "\n", - "$$\\ma\\Sigma = \\ma K'' - \\ma K'\\,(\\ma K+\\sigma_n^2\\,\\ma I)^{-1}\\,\\ma K'^\\top$$\n", + "$$\\ma\\Sigma = \\testtest{\\ma K} - \\test{\\ma K}\\,(\\ma K+\\sigma_n^2\\,\\ma I)^{-1}\\,\\test{\\ma K}^\\top$$\n", "\n", - "See\n", - "https://elcorto.github.io/gp_playground/content/gp_pred_comp/notebook_plot.html\n", - "for details." + "where $\\ma\\Sigma\\equiv\\cov(\\test{\\ve f})$ is the posterior predictive\n", + "covariance matrix (R&W 2006, eq. 2.24), with $\\ma K = K(\\ma X, \\ma X)$,\n", + "$\\test{\\ma K} = K(\\test{\\ma X}, \\ma X)$ and $\\testtest{\\ma K} = K(\\test{\\ma X}, \\test{\\ma X})$.\n", + "\n", + "We find that $\\ma\\Sigma$ reflects behavior we would like to see from\n", + "epistemic uncertainty -- it is high when we have no data\n", + "(out-of-distribution). But this alone isn't the whole story. We need to add\n", + "the estimated noise level $\\sigma_n^2$ in order for the confidence band to\n", + "cover the data."
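+    ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "a1f3c0d5",
+   "metadata": {},
+   "source": [
+    "Before plotting, we check the two equations above with a small sketch.\n",
+    "This is *not* how GPyTorch computes predictions internally (it uses\n",
+    "numerically safer Cholesky-based routines and lazy tensors); we simply\n",
+    "build the kernel matrices from the trained hyper params and plug them\n",
+    "into the equations. Names ending in `_m` (manual) are ours, everything\n",
+    "else is defined in this notebook."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "a1f3c0d6",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sketch only: verify the kernel and posterior predictive equations by\n",
+    "# hand, using the trained hyper params.\n",
+    "model.eval()\n",
+    "likelihood.eval()\n",
+    "with torch.no_grad():\n",
+    "    ell = model.covar_module.base_kernel.lengthscale\n",
+    "    s = model.covar_module.outputscale\n",
+    "    c = model.mean_module.constant\n",
+    "    sigma_n2 = model.likelihood.noise_covar.noise\n",
+    "    # ensure 2D inputs of shape (N, 1) for torch.cdist\n",
+    "    Xt = X_train.reshape(len(X_train), -1)\n",
+    "    Xp = X_pred.reshape(len(X_pred), -1)\n",
+    "    # squared exponential kernel, eq. from the 'Define GP model' section\n",
+    "    K_m = s * torch.exp(-torch.cdist(Xt, Xt) ** 2 / (2 * ell**2))\n",
+    "    K_test_m = s * torch.exp(-torch.cdist(Xp, Xt) ** 2 / (2 * ell**2))\n",
+    "    K_testtest_m = s * torch.exp(-torch.cdist(Xp, Xp) ** 2 / (2 * ell**2))\n",
+    "    # posterior predictive mean and covariance, eq. above\n",
+    "    A_m = K_m + sigma_n2 * torch.eye(len(Xt))\n",
+    "    f_mean_m = c + K_test_m @ torch.linalg.solve(A_m, y_train - c)\n",
+    "    Sigma_m = K_testtest_m - K_test_m @ torch.linalg.solve(A_m, K_test_m.T)\n",
+    "    # compare against GPyTorch (loose tolerance, different numerics)\n",
+    "    post_pred_f_m = model(X_pred)\n",
+    "    print(torch.allclose(post_pred_f_m.mean, f_mean_m, atol=1e-4))\n",
+    "    print(torch.allclose(post_pred_f_m.covariance_matrix, Sigma_m, atol=1e-4))"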
] }, { "cell_type": "code", "execution_count": null, - "id": "e56fd041", + "id": "5821ae0f", "metadata": {}, "outputs": [], "source": [ @@ -437,7 +432,10 @@ " post_pred_y = likelihood(model(X_pred))\n", "\n", " fig, axs = plt.subplots(ncols=2, figsize=(12, 5))\n", - " for ii, (ax, post_pred) in enumerate(zip(axs, [post_pred_f, post_pred_y])):\n", + " fig_sigmas, ax_sigmas = plt.subplots()\n", + " for ii, (ax, post_pred, name) in enumerate(\n", + " zip(axs, [post_pred_f, post_pred_y], [\"f\", \"y\"])\n", + " ):\n", " yf_mean = post_pred.mean\n", " yf_samples = post_pred.sample(sample_shape=torch.Size((10,)))\n", "\n", @@ -467,6 +465,23 @@ " color=\"tab:orange\",\n", " alpha=0.3,\n", " )\n", + " if name == \"f\":\n", + " sigma_label = r\"$\\pm 2\\sqrt{\\mathrm{diag}(\\Sigma)}$\"\n", + " zorder = 1\n", + " else:\n", + " sigma_label = (\n", + " r\"$\\pm 2\\sqrt{\\mathrm{diag}(\\Sigma + \\sigma_n^2\\,I)}$\"\n", + " )\n", + " zorder = 0\n", + " ax_sigmas.fill_between(\n", + " X_pred.numpy(),\n", + " lower.numpy(),\n", + " upper.numpy(),\n", + " label=\"confidence \" + sigma_label,\n", + " color=\"tab:orange\" if name == \"f\" else \"tab:blue\",\n", + " alpha=0.5,\n", + " zorder=zorder,\n", + " )\n", " y_min = y_train.min()\n", " y_max = y_train.max()\n", " y_span = y_max - y_min\n", @@ -474,6 +489,7 @@ " plot_samples(ax, X_pred, yf_samples, label=\"posterior pred. samples\")\n", " if ii == 1:\n", " ax.legend()\n", + " ax_sigmas.legend()\n", "\n", "# When running as script\n", "if not is_interactive():\n", diff --git a/BLcourse2.3/02_two_dim.ipynb b/BLcourse2.3/02_two_dim.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..4216214ffb5d21cb15bdfe4935cd99f9b581acb8 --- /dev/null +++ b/BLcourse2.3/02_two_dim.ipynb @@ -0,0 +1,595 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "7cc3b0fe", + "metadata": {}, + "outputs": [], + "source": [ + "%matplotlib notebook\n", + "##%matplotlib widget\n", + "##%matplotlib inline" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "194927e4", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "from collections import defaultdict\n", + "from pprint import pprint\n", + "\n", + "import torch\n", + "import gpytorch\n", + "from matplotlib import pyplot as plt\n", + "from matplotlib import is_interactive\n", + "from mpl_toolkits.mplot3d import Axes3D\n", + "import numpy as np\n", + "\n", + "from sklearn.preprocessing import StandardScaler\n", + "\n", + "from utils import extract_model_params, plot_samples, fig_ax_3d" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "44d9d58d", + "metadata": {}, + "outputs": [], + "source": [ + "torch.set_default_dtype(torch.float64)\n", + "torch.manual_seed(123)" + ] + }, + { + "cell_type": "markdown", + "id": "4aa6e6bd", + "metadata": { + "lines_to_next_cell": 2 + }, + "source": [ + "# Generate toy 2D data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3d0d40c6", + "metadata": {}, + "outputs": [], + "source": [ + "class MexicanHat:\n", + " def __init__(self, xlim, ylim, nx, ny, mode, **kwds):\n", + " self.xlim = xlim\n", + " self.ylim = ylim\n", + " self.nx = nx\n", + " self.ny = ny\n", + " self.xg, self.yg = self._get_xy_grid()\n", + " self.XG, self.YG = self._get_meshgrids(self.xg, self.yg)\n", + " self.X = self._make_X(mode)\n", + " self.z = self.func(self.X)\n", + "\n", + " def _make_X(self, mode=\"grid\"):\n", + " if mode == \"grid\":\n", + " X = torch.empty((self.nx * self.ny, 
2))\n", + "            X[:, 0] = self.XG.flatten()\n", + "            X[:, 1] = self.YG.flatten()\n", + "        elif mode == \"rand\":\n", + "            X = torch.rand(self.nx * self.ny, 2)\n", + "            X[:, 0] = X[:, 0] * (self.xlim[1] - self.xlim[0]) + self.xlim[0]\n", + "            X[:, 1] = X[:, 1] * (self.ylim[1] - self.ylim[0]) + self.ylim[0]\n", + "        return X\n", + "\n", + "    def _get_xy_grid(self):\n", + "        x = torch.linspace(self.xlim[0], self.xlim[1], self.nx)\n", + "        y = torch.linspace(self.ylim[0], self.ylim[1], self.ny)\n", + "        return x, y\n", + "\n", + "    @staticmethod\n", + "    def _get_meshgrids(xg, yg):\n", + "        return torch.meshgrid(xg, yg, indexing=\"ij\")\n", + "\n", + "    @staticmethod\n", + "    def func(X):\n", + "        \"\"\"Mexican hat function sin(r)/r with r = |x|, applied row-wise.\"\"\"\n", + "        r = torch.sqrt((X**2).sum(axis=1))\n", + "        return torch.sin(r) / r\n", + "\n", + "    @staticmethod\n", + "    def n2t(x):\n", + "        return torch.from_numpy(x)\n", + "\n", + "    def apply_scalers(self, x_scaler, y_scaler):\n", + "        self.X = self.n2t(x_scaler.transform(self.X))\n", + "        Xtmp = x_scaler.transform(torch.stack((self.xg, self.yg), dim=1))\n", + "        self.XG, self.YG = self._get_meshgrids(\n", + "            self.n2t(Xtmp[:, 0]), self.n2t(Xtmp[:, 1])\n", + "        )\n", + "        self.z = self.n2t(y_scaler.transform(self.z[:, None])[:, 0])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "87f84642", + "metadata": {}, + "outputs": [], + "source": [ + "data_train = MexicanHat(\n", + "    xlim=[-15, 25], ylim=[-15, 5], nx=20, ny=20, mode=\"rand\"\n", + ")\n", + "x_scaler = StandardScaler().fit(data_train.X)\n", + "y_scaler = StandardScaler().fit(data_train.z[:, None])\n", + "data_train.apply_scalers(x_scaler, y_scaler)\n", + "\n", + "data_pred = MexicanHat(\n", + "    xlim=[-15, 25], ylim=[-15, 5], nx=100, ny=100, mode=\"grid\"\n", + ")\n", + "data_pred.apply_scalers(x_scaler, y_scaler)\n", + "\n", + "# train inputs\n", + "X_train = data_train.X\n", + "\n", + "# inputs for prediction and plotting\n", + "X_pred = data_pred.X" + ] + }, + { + "cell_type": "markdown", + "id": "158c3da2", + "metadata": {}, + "source": [ + "# Exercise 1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eb9e4ee1", + "metadata": {}, + "outputs": [], + "source": [ + "use_noise = False\n", + "use_gap = False" + ] + }, + { + "cell_type": "markdown", + "id": "197f7480", + "metadata": {}, + "source": [ + "# Exercise 2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3f1ba561", + "metadata": {}, + "outputs": [], + "source": [ + "##use_noise = True\n", + "##use_gap = False" + ] + }, + { + "cell_type": "markdown", + "id": "56b080da", + "metadata": {}, + "source": [ + "# Exercise 3" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "77e9779e", + "metadata": {}, + "outputs": [], + "source": [ + "##use_noise = False\n", + "##use_gap = True" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d3f0b5d7", + "metadata": {}, + "outputs": [], + "source": [ + "if use_noise:\n", + "    # noisy train data\n", + "    noise_std = 0.2\n", + "    noise_dist = torch.distributions.Normal(loc=0, scale=noise_std)\n", + "    # sample() replaces the deprecated sample_n()\n", + "    y_train = data_train.z + noise_dist.sample((len(data_train.z),))\n", + "else:\n", + "    # noise-free train data\n", + "    noise_std = 0\n", + "    y_train = data_train.z" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6f6b3548", + "metadata": {}, + "outputs": [], + "source": [ + "# Cut out part of the train data to create out-of-distribution predictions\n", + "\n", + "if use_gap:\n", + "    mask = (X_train[:, 0] > 0) & (X_train[:, 1] < 0)\n", + "    X_train = X_train[~mask, 
:]\n", + " y_train = y_train[~mask]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "30dbecca", + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = fig_ax_3d()\n", + "s0 = ax.plot_surface(\n", + " data_pred.XG,\n", + " data_pred.YG,\n", + " data_pred.z.reshape((data_pred.nx, data_pred.ny)),\n", + " color=\"tab:grey\",\n", + " alpha=0.5,\n", + ")\n", + "s1 = ax.scatter(\n", + " xs=X_train[:, 0],\n", + " ys=X_train[:, 1],\n", + " zs=y_train,\n", + " color=\"tab:blue\",\n", + " alpha=0.5,\n", + ")\n", + "ax.set_xlabel(\"X_0\")\n", + "ax.set_ylabel(\"X_1\")" + ] + }, + { + "cell_type": "markdown", + "id": "56d27ef1", + "metadata": { + "lines_to_next_cell": 2 + }, + "source": [ + "# Define GP model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6da6f37a", + "metadata": {}, + "outputs": [], + "source": [ + "class ExactGPModel(gpytorch.models.ExactGP):\n", + " \"\"\"API:\n", + "\n", + " model.forward() prior f_pred\n", + " model() posterior f_pred\n", + "\n", + " likelihood(model.forward()) prior with noise y_pred\n", + " likelihood(model()) posterior with noise y_pred\n", + " \"\"\"\n", + "\n", + " def __init__(self, X_train, y_train, likelihood):\n", + " super().__init__(X_train, y_train, likelihood)\n", + " self.mean_module = gpytorch.means.ConstantMean()\n", + " self.covar_module = gpytorch.kernels.ScaleKernel(\n", + " gpytorch.kernels.RBFKernel()\n", + " )\n", + "\n", + " def forward(self, x):\n", + " \"\"\"The prior, defined in terms of the mean and covariance function.\"\"\"\n", + " mean_x = self.mean_module(x)\n", + " covar_x = self.covar_module(x)\n", + " return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)\n", + "\n", + "\n", + "likelihood = gpytorch.likelihoods.GaussianLikelihood()\n", + "model = ExactGPModel(X_train, y_train, likelihood)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "50695462", + "metadata": {}, + "outputs": [], + "source": [ + "# Inspect the model\n", + "print(model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "393a458d", + "metadata": {}, + "outputs": [], + "source": [ + "# Default start hyper params\n", + "pprint(extract_model_params(model, raw=False))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "95ce5bf2", + "metadata": {}, + "outputs": [], + "source": [ + "# Set new start hyper params\n", + "model.mean_module.constant = 0.0\n", + "model.covar_module.base_kernel.lengthscale = 3.0\n", + "model.covar_module.outputscale = 8.0\n", + "model.likelihood.noise_covar.noise = 0.1\n", + "\n", + "pprint(extract_model_params(model, raw=False))" + ] + }, + { + "cell_type": "markdown", + "id": "56ccbafa", + "metadata": {}, + "source": [ + "# Fit GP to data: optimize hyper params" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d44f3667", + "metadata": {}, + "outputs": [], + "source": [ + "# Train mode\n", + "model.train()\n", + "likelihood.train()\n", + "\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=0.2)\n", + "loss_func = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)\n", + "\n", + "n_iter = 300\n", + "history = defaultdict(list)\n", + "for ii in range(n_iter):\n", + " optimizer.zero_grad()\n", + " loss = -loss_func(model(X_train), y_train)\n", + " loss.backward()\n", + " optimizer.step()\n", + " if (ii + 1) % 10 == 0:\n", + " print(f\"iter {ii+1}/{n_iter}, {loss=:.3f}\")\n", + " for p_name, p_val in extract_model_params(model).items():\n", + " 
history[p_name].append(p_val)\n", + "    history[\"loss\"].append(loss.item())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d1ff9c36", + "metadata": {}, + "outputs": [], + "source": [ + "ncols = len(history)\n", + "fig, axs = plt.subplots(ncols=ncols, nrows=1, figsize=(ncols * 5, 5))\n", + "for ax, (p_name, p_lst) in zip(axs, history.items()):\n", + "    ax.plot(p_lst)\n", + "    ax.set_title(p_name)\n", + "    ax.set_xlabel(\"iterations\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c6c9c1b2", + "metadata": {}, + "outputs": [], + "source": [ + "# Values of optimized hyper params\n", + "pprint(extract_model_params(model, raw=False))" + ] + }, + { + "cell_type": "markdown", + "id": "b01dac44", + "metadata": {}, + "source": [ + "# Run prediction\n", + "\n", + "We compare the posterior predictive mean against the ground truth\n", + "surface." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5bb6926d", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "model.eval()\n", + "likelihood.eval()\n", + "\n", + "with torch.no_grad():\n", + "    post_pred_f = model(X_pred)\n", + "    post_pred_y = likelihood(model(X_pred))\n", + "\n", + "    fig = plt.figure()\n", + "    ax = fig.add_subplot(projection=\"3d\")\n", + "    ax.plot_surface(\n", + "        data_pred.XG,\n", + "        data_pred.YG,\n", + "        data_pred.z.reshape((data_pred.nx, data_pred.ny)),\n", + "        color=\"tab:grey\",\n", + "        alpha=0.5,\n", + "    )\n", + "    ax.plot_surface(\n", + "        data_pred.XG,\n", + "        data_pred.YG,\n", + "        post_pred_y.mean.reshape((data_pred.nx, data_pred.ny)),\n", + "        color=\"tab:red\",\n", + "        alpha=0.5,\n", + "    )\n", + "    ax.set_xlabel(\"X_0\")\n", + "    ax.set_ylabel(\"X_1\")\n", + "\n", + "assert (post_pred_f.mean == post_pred_y.mean).all()" + ] + }, + { + "cell_type": "markdown", + "id": "9a354f60", + "metadata": {}, + "source": [ + "# Plot difference to ground truth and uncertainty" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f3570739", + "metadata": {}, + "outputs": [], + "source": [ + "ncols = 3\n", + "fig, axs = plt.subplots(ncols=ncols, nrows=1, figsize=(ncols * 7, 5))\n", + "\n", + "vmax = post_pred_y.stddev.max()\n", + "cs = []\n", + "\n", + "cs.append(\n", + "    axs[0].contourf(\n", + "        data_pred.XG,\n", + "        data_pred.YG,\n", + "        torch.abs(post_pred_y.mean - data_pred.z).reshape(\n", + "            (data_pred.nx, data_pred.ny)\n", + "        ),\n", + "    )\n", + ")\n", + "axs[0].set_title(\"|y_pred - y_true|\")\n", + "\n", + "cs.append(\n", + "    axs[1].contourf(\n", + "        data_pred.XG,\n", + "        data_pred.YG,\n", + "        post_pred_f.stddev.reshape((data_pred.nx, data_pred.ny)),\n", + "        vmin=0,\n", + "        vmax=vmax,\n", + "    )\n", + ")\n", + "axs[1].set_title(\"f_std (epistemic)\")\n", + "\n", + "cs.append(\n", + "    axs[2].contourf(\n", + "        data_pred.XG,\n", + "        data_pred.YG,\n", + "        post_pred_y.stddev.reshape((data_pred.nx, data_pred.ny)),\n", + "        vmin=0,\n", + "        vmax=vmax,\n", + "    )\n", + ")\n", + "axs[2].set_title(\"y_std (epistemic + aleatoric)\")\n", + "\n", + "for ax, c in zip(axs, cs):\n", + "    ax.set_xlabel(\"X_0\")\n", + "    ax.set_ylabel(\"X_1\")\n", + "    ax.scatter(x=X_train[:, 0], y=X_train[:, 1], color=\"white\", alpha=0.2)\n", + "    fig.colorbar(c, ax=ax)" + ] + }, + { + "cell_type": "markdown", + "id": "9bfc86ee", + "metadata": {}, + "source": [ + "# Check learned noise\n", + "\n", + "Since $\\mathrm{var}(y) = \\mathrm{var}(f) + \\sigma_n^2$, the first value\n", + "printed below should match the learned likelihood noise $\\sigma_n$\n", + "(second value), which in turn should be close to the true noise level\n", + "used to generate the data (third value, zero in the noise-free case)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6ad91b3b", + "metadata": {}, + "outputs": [], + "source": [ + "print((post_pred_y.stddev**2 - post_pred_f.stddev**2).mean().sqrt())\n", + "print(\n", + "    np.sqrt(\n", + "        extract_model_params(model, raw=False)[\"likelihood.noise_covar.noise\"]\n", + "    )\n", + ")\n", + "print(noise_std)" + ] + },
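+ { + "cell_type": "markdown", + "id": "b2d4e8f1", + "metadata": {}, + "source": [ + "A stricter version of the same check (again only a sketch): the identity\n", + "$\\mathrm{var}(y) = \\mathrm{var}(f) + \\sigma_n^2$ should hold at every\n", + "point in `X_pred`, not just on average." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b2d4e8f2", + "metadata": {}, + "outputs": [], + "source": [ + "# Sketch: pointwise version of the check above. Loose tolerance, since\n", + "# the two variances travel different code paths in GPyTorch.\n", + "sigma_n2 = extract_model_params(model, raw=False)[\n", + "    \"likelihood.noise_covar.noise\"\n", + "]\n", + "print(\n", + "    torch.allclose(\n", + "        post_pred_y.stddev**2,\n", + "        post_pred_f.stddev**2 + sigma_n2,\n", + "        atol=1e-5,\n", + "    )\n", + ")" + ] + },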
+ { + "cell_type": "markdown", + "id": "18e5a5ee", + "metadata": {}, + "source": [ + "# Plot confidence bands\n", + "\n", + "We plot the $\\pm 2\\sigma$ band around the posterior predictive mean and\n", + "show the predictive standard deviation as a filled contour below the\n", + "surfaces." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "46c9aab5", + "metadata": {}, + "outputs": [], + "source": [ + "y_mean = post_pred_y.mean.reshape((data_pred.nx, data_pred.ny))\n", + "y_std = post_pred_y.stddev.reshape((data_pred.nx, data_pred.ny))\n", + "upper = y_mean + 2 * y_std\n", + "lower = y_mean - 2 * y_std\n", + "\n", + "fig, ax = fig_ax_3d()\n", + "for Z, color in [(upper, \"tab:green\"), (lower, \"tab:red\")]:\n", + "    ax.plot_surface(\n", + "        data_pred.XG,\n", + "        data_pred.YG,\n", + "        Z,\n", + "        color=color,\n", + "        alpha=0.5,\n", + "    )\n", + "\n", + "# make room below the surfaces for the contour plot\n", + "contour_z = lower.min() - 1\n", + "zlim = ax.get_zlim()\n", + "ax.set_zlim((contour_z, zlim[1] + abs(contour_z)))\n", + "ax.contourf(data_pred.XG, data_pred.YG, y_std, zdir=\"z\", offset=contour_z)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "81788c9d", + "metadata": {}, + "outputs": [], + "source": [ + "# When running as script\n", + "if not is_interactive():\n", + "    plt.show()" + ] + } + ], + "metadata": { + "jupytext": { + "formats": "ipynb,py:light" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}