diff --git a/BLcourse2.3/03_one_dim_SVI.ipynb b/BLcourse2.3/03_one_dim_SVI.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..cdb7789d7ffdad559216fcb5701592637467479f --- /dev/null +++ b/BLcourse2.3/03_one_dim_SVI.ipynb @@ -0,0 +1,536 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "b146bc52", + "metadata": {}, + "source": [ + "# About\n", + "\n", + "In this notebook, we replace ExactGP inference and log marginal\n", + "likelihood optimization with sparse stochastic variational inference.\n", + "This serves as an example of the many methods `gpytorch` offers to make GPs\n", + "scale to large data sets.\n", + "$\newcommand{\ve}[1]{\mathit{\boldsymbol{#1}}}$\n", + "$\newcommand{\ma}[1]{\mathbf{#1}}$\n", + "$\newcommand{\pred}[1]{\rm{#1}}$\n", + "$\newcommand{\predve}[1]{\mathbf{#1}}$\n", + "$\newcommand{\test}[1]{#1_*}$\n", + "$\newcommand{\testtest}[1]{#1_{**}}$\n", + "$\DeclareMathOperator{\diag}{diag}$\n", + "$\DeclareMathOperator{\cov}{cov}$" + ] + }, + { + "cell_type": "markdown", + "id": "f96e3304", + "metadata": {}, + "source": [ + "# Imports, helpers, setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "193b2f13", + "metadata": {}, + "outputs": [], + "source": [ + "##%matplotlib notebook\n", + "%matplotlib widget\n", + "##%matplotlib inline" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c11c1030", + "metadata": {}, + "outputs": [], + "source": [ + "import math\n", + "from collections import defaultdict\n", + "from pprint import pprint\n", + "\n", + "import torch\n", + "import gpytorch\n", + "from matplotlib import pyplot as plt\n", + "from matplotlib import is_interactive\n", + "import numpy as np\n", + "from torch.utils.data import TensorDataset, DataLoader\n", + "\n", + "from utils import extract_model_params, plot_samples\n", + "\n", + "\n", + "torch.set_default_dtype(torch.float64)\n", + "torch.manual_seed(123)" + ] + }, + { + "cell_type": "markdown", + "id": "460d4c3d", + "metadata": { + "lines_to_next_cell": 2 + }, + "source": [ + "# Generate toy 1D data\n", + "\n", + "Now we generate 10x more points than in the ExactGP case, yet inference\n", + "won't be much slower (exact GPs, by contrast, scale roughly as $N^3$). Note that the data we\n", + "use here is still tiny (1000 points is easy even for exact GPs), so the\n", + "method's benefits cannot be fully demonstrated in our example\n", + "-- also we don't even use a GPU yet :).\n",
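+ "\n", + "For rough orientation (standard ballpark figures for sparse variational GPs,\n", + "not anything computed in this notebook): exact GP inference costs about\n", + "$\mathcal O(N^3)$ time and $\mathcal O(N^2)$ memory in the number of training\n", + "points $N$, whereas the sparse variational model used below costs about\n", + "\n", + "$$\n", + "\mathcal O(B\,M^2 + M^3) \quad \text{per optimizer step},\n", + "$$\n", + "\n", + "with $M \ll N$ inducing points and mini-batch size $B$, so the per-step cost\n", + "does not grow with $N$."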
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f1a9647d", + "metadata": {}, + "outputs": [], + "source": [ + "def ground_truth(x, const):\n", + " return torch.sin(x) * torch.exp(-0.2 * x) + const\n", + "\n", + "\n", + "def generate_data(x, gaps=[[1, 3]], const=None, noise_std=None):\n", + " noise_dist = torch.distributions.Normal(loc=0, scale=noise_std)\n", + " y = ground_truth(x, const=const) + noise_dist.sample(\n", + " sample_shape=(len(x),)\n", + " )\n", + " msk = torch.tensor([True] * len(x))\n", + " if gaps is not None:\n", + " for g in gaps:\n", + " msk = msk & ~((x > g[0]) & (x < g[1]))\n", + " return x[msk], y[msk], y\n", + "\n", + "\n", + "const = 5.0\n", + "noise_std = 0.1\n", + "x = torch.linspace(0, 4 * math.pi, 1000)\n", + "X_train, y_train, y_gt_train = generate_data(\n", + " x, gaps=[[6, 10]], const=const, noise_std=noise_std\n", + ")\n", + "X_pred = torch.linspace(\n", + " X_train[0] - 2, X_train[-1] + 2, 200, requires_grad=False\n", + ")\n", + "y_gt_pred = ground_truth(X_pred, const=const)\n", + "\n", + "print(f\"{X_train.shape=}\")\n", + "print(f\"{y_train.shape=}\")\n", + "print(f\"{X_pred.shape=}\")\n", + "\n", + "fig, ax = plt.subplots()\n", + "ax.scatter(X_train, y_train, marker=\"o\", color=\"tab:blue\", label=\"noisy data\")\n", + "ax.plot(X_pred, y_gt_pred, ls=\"--\", color=\"k\", label=\"ground truth\")\n", + "ax.legend()" + ] + }, + { + "cell_type": "markdown", + "id": "4ad60c5e", + "metadata": { + "lines_to_next_cell": 2 + }, + "source": [ + "# Define GP model\n", + "\n", + "The model follows [this\n", + "example](https://docs.gpytorch.ai/en/stable/examples/04_Variational_and_Approximate_GPs/SVGP_Regression_CUDA.html)\n", + "based on [Hensman et al., \"Scalable Variational Gaussian Process Classification\",\n", + "2015](https://proceedings.mlr.press/v38/hensman15.html). 
The model is\n", + "\"sparse\" since it works with a set of *inducing* points $(\ma Z, \ve u),\n", + "\ve u=f(\ma Z)$ which is much smaller than the train data $(\ma X, \ve y)$.\n", + "\n", + "We have the same hyper parameters as before:\n", + "\n", + "* $\ell$ = `model.covar_module.base_kernel.lengthscale`\n", + "* $\sigma_n^2$ = `likelihood.noise_covar.noise`\n", + "* $s$ = `model.covar_module.outputscale`\n", + "* $m(\ve x) = c$ = `model.mean_module.constant`\n", + "\n", + "plus additional ones, introduced by the approximations used:\n", + "\n", + "* the learnable inducing points $\ma Z$ for the variational distribution\n", + " $q_{\ve\psi}(\ve u)$\n", + "* learnable parameters of the variational distribution $q_{\ve\psi}(\ve u)=\mathcal N(\ve m_u, \ma S)$: the\n", + " variational mean $\ve m_u$ and the covariance $\ma S$, parameterized by a lower triangular\n", + " matrix $\ma L$ such that $\ma S=\ma L\,\ma L^\top$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "444929a3", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "class ApproxGPModel(gpytorch.models.ApproximateGP):\n", + " def __init__(self, Z):\n", + " # Approximate inducing value posterior q(u), u = f(Z), Z = inducing\n", + " # points (subset of X_train)\n", + " variational_distribution = (\n", + " gpytorch.variational.CholeskyVariationalDistribution(Z.size(0))\n", + " )\n", + " # Compute q(f(X)) from q(u)\n", + " variational_strategy = gpytorch.variational.VariationalStrategy(\n", + " self,\n", + " Z,\n", + " variational_distribution,\n", + " learn_inducing_locations=True,\n", + " )\n", + " super().__init__(variational_strategy)\n", + " self.mean_module = gpytorch.means.ConstantMean()\n", + " self.covar_module = gpytorch.kernels.ScaleKernel(\n", + " gpytorch.kernels.RBFKernel()\n", + " )\n", + "\n", + " def forward(self, x):\n", + " mean_x = self.mean_module(x)\n", + " covar_x = self.covar_module(x)\n", + " return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)\n", + "\n", + "\n", + "likelihood = gpytorch.likelihoods.GaussianLikelihood()\n", + "\n", + "n_train = len(X_train)\n", + "# Start values for inducing points Z, use 10% random sub-sample of X_train.\n", + "ind_points_fraction = 0.1\n", + "ind_idxs = torch.randperm(n_train)[: int(n_train * ind_points_fraction)]\n", + "model = ApproxGPModel(Z=X_train[ind_idxs])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5861163d", + "metadata": {}, + "outputs": [], + "source": [ + "# Inspect the model\n", + "print(model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "98032457", + "metadata": {}, + "outputs": [], + "source": [ + "# Inspect the likelihood.
In contrast to ExactGP, the likelihood is not part of\n", + "# the GP model instance.\n", + "print(likelihood)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7ec97687", + "metadata": {}, + "outputs": [], + "source": [ + "# Default start hyper params\n", + "print(\"model params:\")\n", + "pprint(extract_model_params(model))\n", + "print(\"likelihood params:\")\n", + "pprint(extract_model_params(likelihood))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "167d584c", + "metadata": {}, + "outputs": [], + "source": [ + "# Set new start hyper params\n", + "model.mean_module.constant = 3.0\n", + "model.covar_module.base_kernel.lengthscale = 1.0\n", + "model.covar_module.outputscale = 1.0\n", + "likelihood.noise_covar.noise = 0.3" + ] + }, + { + "cell_type": "markdown", + "id": "6704ce2f", + "metadata": {}, + "source": [ + "# Fit GP to data: optimize hyper params\n", + "\n", + "Now we optimize the GP hyper parameters using a GP-specific form of variational inference (VI):\n", + "instead of the log marginal likelihood (as in the ExactGP case),\n", + "we maximize an ELBO (evidence lower bound) objective\n", + "\n", + "$$\n", + "\max_\ve\zeta\left(\mathbb E_{q_{\ve\psi}(\ve u)}\left[\ln p(\ve y|\ve u) \right] -\n", + "D_{\text{KL}}(q_{\ve\psi}(\ve u)\Vert p(\ve u))\right)\n", + "$$\n", + "\n", + "with respect to\n", + "\n", + "$$\ve\zeta = [\ell, \sigma_n^2, s, c, \ve\psi] $$\n", + "\n", + "with\n", + "\n", + "$$\ve\psi = [\ve m_u, \ma Z, \ma L]$$\n", + "\n", + "the parameters of the variational distribution $q_{\ve\psi}(\ve u)$.\n", + "\n", + "In addition, we perform stochastic\n", + "optimization using a deep-learning-style mini-batch loop, hence\n", + "\"stochastic\" variational inference (SVI). This speeds up the\n", + "optimization since we only look at a fraction of the data per optimizer step to\n", + "calculate an approximate loss gradient (`loss.backward()`).\n",
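+ "\n", + "Concretely, for a mini-batch $\mathcal B \subset \{1,\dots,N\}$ of the $N$ training\n", + "points, the per-step objective is (a sketch following Hensman et al.; up to an\n", + "overall normalization this is what `gpytorch.mlls.VariationalELBO` with\n", + "`num_data=N` evaluates)\n", + "\n", + "$$\n", + "\frac{N}{|\mathcal B|}\sum_{i\in\mathcal B}\mathbb E_{q_{\ve\psi}(\ve u)}\left[\ln p(y_i|\ve u) \right] -\n", + "D_{\text{KL}}(q_{\ve\psi}(\ve u)\Vert p(\ve u))\n", + "$$\n", + "\n", + "The factor $N/|\mathcal B|$ makes the mini-batch likelihood term an unbiased\n", + "estimate of the sum over all $N$ points, which is why `num_data=X_train.shape[0]`\n", + "is passed below."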
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f39d86f8", + "metadata": {}, + "outputs": [], + "source": [ + "# Train mode\n", + "model.train()\n", + "likelihood.train()\n", + "\n", + "optimizer = torch.optim.Adam(\n", + " [dict(params=model.parameters()), dict(params=likelihood.parameters())],\n", + " lr=0.1,\n", + ")\n", + "loss_func = gpytorch.mlls.VariationalELBO(\n", + " likelihood, model, num_data=X_train.shape[0]\n", + ")\n", + "\n", + "train_dl = DataLoader(\n", + " TensorDataset(X_train, y_train), batch_size=128, shuffle=True\n", + ")\n", + "\n", + "n_iter = 200\n", + "history = defaultdict(list)\n", + "for i_iter in range(n_iter):\n", + " # Collect per-batch values here, store their per-iteration mean below\n", + " batch_history = defaultdict(list)\n", + " for i_batch, (X_batch, y_batch) in enumerate(train_dl):\n", + " optimizer.zero_grad()\n", + " loss = -loss_func(model(X_batch), y_batch)\n", + " loss.backward()\n", + " optimizer.step()\n", + " param_dct = dict()\n", + " param_dct.update(extract_model_params(model, try_item=True))\n", + " param_dct.update(extract_model_params(likelihood, try_item=True))\n", + " for p_name, p_val in param_dct.items():\n", + " if isinstance(p_val, float):\n", + " batch_history[p_name].append(p_val)\n", + " batch_history[\"loss\"].append(loss.item())\n", + " for p_name, p_lst in batch_history.items():\n", + " history[p_name].append(np.mean(p_lst))\n", + " if (i_iter + 1) % 10 == 0:\n", + " print(f\"iter {i_iter + 1}/{n_iter}, {loss=:.3f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dc39b5d9", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot scalar hyper params and loss (ELBO) convergence\n", + "ncols = len(history)\n", + "fig, axs = plt.subplots(\n", + " ncols=ncols, nrows=1, figsize=(ncols * 3, 3), layout=\"compressed\"\n", + ")\n", + "with torch.no_grad():\n", + " for ax, (p_name, p_lst) in zip(axs, history.items()):\n", + " ax.plot(p_lst)\n", + " ax.set_title(p_name)\n", + " ax.set_xlabel(\"iterations\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c4907e39", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "# Values of optimized hyper params\n", + "print(\"model params:\")\n", + "pprint(extract_model_params(model))\n", + "print(\"likelihood params:\")\n", + "pprint(extract_model_params(likelihood))" + ] + }, + { + "cell_type": "markdown", + "id": "4ae35e11", + "metadata": {}, + "source": [ + "# Run prediction" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "629b4658", + "metadata": {}, + "outputs": [], + "source": [ + "# Evaluation (predictive posterior) mode\n", + "model.eval()\n", + "likelihood.eval()\n", + "\n", + "with torch.no_grad():\n", + " M = 10\n", + " post_pred_f = model(X_pred)\n", + " post_pred_y = likelihood(model(X_pred))\n", + "\n", + " fig, axs = plt.subplots(ncols=2, figsize=(14, 5), sharex=True, sharey=True)\n", + " fig_sigmas, ax_sigmas = plt.subplots()\n", + " for ii, (ax, post_pred, name, title) in enumerate(\n", + " zip(\n", + " axs,\n", + " [post_pred_f, post_pred_y],\n", + " [\"f\", \"y\"],\n", + " [\"epistemic uncertainty\", \"total uncertainty\"],\n", + " )\n", + " ):\n", + " yf_mean = post_pred.mean\n", + " yf_samples = post_pred.sample(sample_shape=torch.Size((M,)))\n", + "\n", + " yf_std = post_pred.stddev\n", + " lower = yf_mean - 2 * yf_std\n", + " upper = yf_mean + 2 * yf_std\n", + " ax.plot(\n", + " X_train.numpy(),\n", + " y_train.numpy(),\n", + " \"o\",\n", + " label=\"data\",\n", + " color=\"tab:blue\",\n", + " )\n", + " ax.plot(\n", + "
X_pred.numpy(),\n", + " yf_mean.numpy(),\n", + " label=\"mean\",\n", + " color=\"tab:red\",\n", + " lw=2,\n", + " )\n", + " ax.plot(\n", + " X_pred.numpy(),\n", + " y_gt_pred.numpy(),\n", + " label=\"ground truth\",\n", + " color=\"k\",\n", + " lw=2,\n", + " ls=\"--\",\n", + " )\n", + " ax.fill_between(\n", + " X_pred.numpy(),\n", + " lower.numpy(),\n", + " upper.numpy(),\n", + " label=\"confidence\",\n", + " color=\"tab:orange\",\n", + " alpha=0.3,\n", + " )\n", + " ax.set_title(f\"confidence = {title}\")\n", + " if name == \"f\":\n", + " sigma_label = r\"epistemic: $\\pm 2\\sqrt{\\mathrm{diag}(\\Sigma_*)}$\"\n", + " zorder = 1\n", + " else:\n", + " sigma_label = (\n", + " r\"total: $\\pm 2\\sqrt{\\mathrm{diag}(\\Sigma_* + \\sigma_n^2\\,I)}$\"\n", + " )\n", + " zorder = 0\n", + " ax_sigmas.fill_between(\n", + " X_pred.numpy(),\n", + " lower.numpy(),\n", + " upper.numpy(),\n", + " label=sigma_label,\n", + " color=\"tab:orange\" if name == \"f\" else \"tab:blue\",\n", + " alpha=0.5,\n", + " zorder=zorder,\n", + " )\n", + " y_min = y_train.min()\n", + " y_max = y_train.max()\n", + " y_span = y_max - y_min\n", + " ax.set_ylim([y_min - 0.3 * y_span, y_max + 0.3 * y_span])\n", + " plot_samples(ax, X_pred, yf_samples, label=\"posterior pred. samples\")\n", + " if ii == 1:\n", + " ax.legend()\n", + " ax_sigmas.set_title(\"total vs. epistemic uncertainty\")\n", + " ax_sigmas.legend()" + ] + }, + { + "cell_type": "markdown", + "id": "60abb7ec", + "metadata": {}, + "source": [ + "# Let's check the learned noise" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d8a32f5b", + "metadata": {}, + "outputs": [], + "source": [ + "# Target noise to learn\n", + "print(\"data noise:\", noise_std)\n", + "\n", + "# The two below must be the same\n", + "print(\n", + " \"learned noise:\",\n", + " (post_pred_y.stddev**2 - post_pred_f.stddev**2).mean().sqrt().item(),\n", + ")\n", + "print(\n", + " \"learned noise:\",\n", + " np.sqrt(\n", + " extract_model_params(likelihood, try_item=True)[\"noise_covar.noise\"]\n", + " ),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0637729a", + "metadata": {}, + "outputs": [], + "source": [ + "# When running as script\n", + "if not is_interactive():\n", + " plt.show()" + ] + } + ], + "metadata": { + "jupytext": { + "formats": "ipynb,py:light" + }, + "kernelspec": { + "display_name": "bayes-ml-course", + "language": "python", + "name": "bayes-ml-course" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}