diff --git a/BLcourse2.3/01_one_dim.ipynb b/BLcourse2.3/01_one_dim.ipynb index f1c2cae93c5a85f2ae9208cb07b48adbaf166ba7..f5208de38e0c4b0cc1a1ed052d6d9393a318aa04 100644 --- a/BLcourse2.3/01_one_dim.ipynb +++ b/BLcourse2.3/01_one_dim.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "5ece8acf", + "id": "2a8ed8ae", "metadata": {}, "source": [ "# Notation\n", @@ -26,7 +26,7 @@ }, { "cell_type": "markdown", - "id": "71f86676", + "id": "ae440984", "metadata": {}, "source": [ "# Imports, helpers, setup" @@ -35,7 +35,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8c581079", + "id": "25b39064", "metadata": {}, "outputs": [], "source": [ @@ -47,7 +47,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2f89a0c6", + "id": "960de85c", "metadata": {}, "outputs": [], "source": [ @@ -82,7 +82,7 @@ }, { "cell_type": "markdown", - "id": "97e05998", + "id": "7a41bb4f", "metadata": { "lines_to_next_cell": 2 }, @@ -99,7 +99,7 @@ { "cell_type": "code", "execution_count": null, - "id": "06c7e542", + "id": "30019345", "metadata": {}, "outputs": [], "source": [ @@ -127,7 +127,7 @@ }, { "cell_type": "markdown", - "id": "28539d2a", + "id": "7d326cde", "metadata": { "lines_to_next_cell": 2 }, @@ -156,7 +156,7 @@ { "cell_type": "code", "execution_count": null, - "id": "84c29689", + "id": "0ce2844b", "metadata": { "lines_to_next_cell": 2 }, @@ -193,7 +193,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ce3d687b", + "id": "d8a2aab8", "metadata": {}, "outputs": [], "source": [ @@ -204,7 +204,7 @@ { "cell_type": "code", "execution_count": null, - "id": "332904ac", + "id": "0a3989c5", "metadata": {}, "outputs": [], "source": [ @@ -215,7 +215,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5142e002", + "id": "5c3808b3", "metadata": { "lines_to_next_cell": 2 }, @@ -232,7 +232,7 @@ }, { "cell_type": "markdown", - "id": "2c09f48a", + "id": "c3600af3", "metadata": {}, "source": [ "# Sample from the GP prior\n", @@ -247,7 +247,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9d766f10", + "id": "e74b8d1b", "metadata": { "lines_to_next_cell": 2 }, @@ -283,7 +283,7 @@ }, { "cell_type": "markdown", - "id": "e31330a2", + "id": "0598b567", "metadata": {}, "source": [ "Let's investigate the samples more closely. A constant mean $\\ve m(\\ma X) =\n", @@ -296,7 +296,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e2372d76", + "id": "6d7fa03e", "metadata": {}, "outputs": [], "source": [ @@ -310,7 +310,7 @@ { "cell_type": "code", "execution_count": null, - "id": "cb1d5d14", + "id": "20b6e415", "metadata": {}, "outputs": [], "source": [ @@ -324,14 +324,84 @@ }, { "cell_type": "markdown", - "id": "9db57693", + "id": "25b90532", + "metadata": {}, + "source": [ + "# GP posterior predictive distribution with fixed hyper params\n", + "\n", + "Now we calculate the posterior predictive distribution $p(\\test{\\predve\n", + "f}|\\test{\\ma X}, \\ma X, \\ve y)$, i.e. we condition on the train data (Bayesian\n", + "inference).\n", + "\n", + "We use the fixed hyper param values defined above. In particular, since\n", + "$\\sigma_n^2$ = `model.likelihood.noise_covar.noise` is > 0, we have a\n", + "regression setting." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d01951f1", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "# Evaluation (predictive posterior) mode\n", + "model.eval()\n", + "likelihood.eval()\n", + "\n", + "with torch.no_grad():\n", + " M = 10\n", + " post_pred_f = model(X_pred)\n", + "\n", + " fig, ax = plt.subplots()\n", + " f_mean = post_pred_f.mean\n", + " f_samples = post_pred_f.sample(sample_shape=torch.Size((M,)))\n", + " f_std = post_pred_f.stddev\n", + " lower = f_mean - 2 * f_std\n", + " upper = f_mean + 2 * f_std\n", + " ax.plot(\n", + " X_train.numpy(),\n", + " y_train.numpy(),\n", + " \"o\",\n", + " label=\"data\",\n", + " color=\"tab:blue\",\n", + " )\n", + " ax.plot(\n", + " X_pred.numpy(),\n", + " f_mean.numpy(),\n", + " label=\"mean\",\n", + " color=\"tab:red\",\n", + " lw=2,\n", + " )\n", + " ax.fill_between(\n", + " X_pred.numpy(),\n", + " lower.numpy(),\n", + " upper.numpy(),\n", + " label=\"confidence\",\n", + " color=\"tab:orange\",\n", + " alpha=0.3,\n", + " )\n", + " y_min = y_train.min()\n", + " y_max = y_train.max()\n", + " y_span = y_max - y_min\n", + " ax.set_ylim([y_min - 0.3 * y_span, y_max + 0.3 * y_span])\n", + " plot_samples(ax, X_pred, f_samples, label=\"posterior pred. samples\")\n", + " ax.legend()" + ] + }, + { + "cell_type": "markdown", + "id": "b960a745", "metadata": {}, "source": [ "# Fit GP to data: optimize hyper params\n", "\n", "In each step of the optimizer, we condition on the training data (e.g. do\n", - "Bayesian inference) to calculate the weight posterior for the current values\n", - "of the hyper params.\n", + "Bayesian inference) to calculate the posterior predictive distribution for\n", + "the current values of the hyper params. We iterate until the log marginal\n", + "likelihood is converged.\n", "\n", "We use a simplistic PyTorch-style hand written train loop without convergence\n", "control, so make sure to use enough `n_iter` and eyeball-check that the loss\n", @@ -344,7 +414,7 @@ { "cell_type": "code", "execution_count": null, - "id": "58266bb8", + "id": "dc059dc2", "metadata": {}, "outputs": [], "source": [ @@ -372,7 +442,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2cc3092c", + "id": "7c9e7052", "metadata": {}, "outputs": [], "source": [ @@ -388,7 +458,7 @@ { "cell_type": "code", "execution_count": null, - "id": "030d1fc8", + "id": "eb1fe908", "metadata": {}, "outputs": [], "source": [ @@ -398,7 +468,7 @@ }, { "cell_type": "markdown", - "id": "bb0c1621", + "id": "98aefb90", "metadata": {}, "source": [ "# Run prediction\n", @@ -419,7 +489,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5821ae0f", + "id": "a78de0e4", "metadata": {}, "outputs": [], "source": [ @@ -428,6 +498,7 @@ "likelihood.eval()\n", "\n", "with torch.no_grad():\n", + " M = 10\n", " post_pred_f = model(X_pred)\n", " post_pred_y = likelihood(model(X_pred))\n", "\n", @@ -437,9 +508,8 @@ " zip(axs, [post_pred_f, post_pred_y], [\"f\", \"y\"])\n", " ):\n", " yf_mean = post_pred.mean\n", - " yf_samples = post_pred.sample(sample_shape=torch.Size((10,)))\n", + " yf_samples = post_pred.sample(sample_shape=torch.Size((M,)))\n", "\n", - " ##lower, upper = post_pred.confidence_region()\n", " yf_std = post_pred.stddev\n", " lower = yf_mean - 2 * yf_std\n", " upper = yf_mean + 2 * yf_std\n", @@ -489,8 +559,16 @@ " plot_samples(ax, X_pred, yf_samples, label=\"posterior pred. samples\")\n", " if ii == 1:\n", " ax.legend()\n", - " ax_sigmas.legend()\n", - "\n", + " ax_sigmas.legend()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e726d3d1", + "metadata": {}, + "outputs": [], + "source": [ "# When running as script\n", "if not is_interactive():\n", " plt.show()" diff --git a/BLcourse2.3/02_two_dim.ipynb b/BLcourse2.3/02_two_dim.ipynb index 4216214ffb5d21cb15bdfe4935cd99f9b581acb8..a87daa957597ef0730b108251a6f5c459c7c8eec 100644 --- a/BLcourse2.3/02_two_dim.ipynb +++ b/BLcourse2.3/02_two_dim.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7cc3b0fe", + "id": "b88a37d9", "metadata": {}, "outputs": [], "source": [ @@ -15,7 +15,7 @@ { "cell_type": "code", "execution_count": null, - "id": "194927e4", + "id": "64c21c28", "metadata": { "lines_to_next_cell": 2 }, @@ -28,18 +28,17 @@ "import gpytorch\n", "from matplotlib import pyplot as plt\n", "from matplotlib import is_interactive\n", - "from mpl_toolkits.mplot3d import Axes3D\n", "import numpy as np\n", "\n", "from sklearn.preprocessing import StandardScaler\n", "\n", - "from utils import extract_model_params, plot_samples, fig_ax_3d" + "from utils import extract_model_params, fig_ax_3d" ] }, { "cell_type": "code", "execution_count": null, - "id": "44d9d58d", + "id": "71da09db", "metadata": {}, "outputs": [], "source": [ @@ -49,7 +48,7 @@ }, { "cell_type": "markdown", - "id": "4aa6e6bd", + "id": "e5786965", "metadata": { "lines_to_next_cell": 2 }, @@ -60,7 +59,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3d0d40c6", + "id": "ee7a8e4a", "metadata": {}, "outputs": [], "source": [ @@ -116,7 +115,7 @@ { "cell_type": "code", "execution_count": null, - "id": "87f84642", + "id": "0058371a", "metadata": {}, "outputs": [], "source": [ @@ -141,7 +140,7 @@ }, { "cell_type": "markdown", - "id": "158c3da2", + "id": "3523c08d", "metadata": {}, "source": [ "# Exercise 1" @@ -150,7 +149,7 @@ { "cell_type": "code", "execution_count": null, - "id": "eb9e4ee1", + "id": "0ec4ad3d", "metadata": {}, "outputs": [], "source": [ @@ -160,7 +159,7 @@ }, { "cell_type": "markdown", - "id": "197f7480", + "id": "f4592094", "metadata": {}, "source": [ "# Exercise 2" @@ -169,7 +168,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3f1ba561", + "id": "6bec0f58", "metadata": {}, "outputs": [], "source": [ @@ -179,7 +178,7 @@ }, { "cell_type": "markdown", - "id": "56b080da", + "id": "865866ea", "metadata": {}, "source": [ "# Exercise 3" @@ -188,7 +187,7 @@ { "cell_type": "code", "execution_count": null, - "id": "77e9779e", + "id": "5132eaeb", "metadata": {}, "outputs": [], "source": [ @@ -199,7 +198,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d3f0b5d7", + "id": "6efe09b3", "metadata": {}, "outputs": [], "source": [ @@ -217,7 +216,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6f6b3548", + "id": "af637911", "metadata": {}, "outputs": [], "source": [ @@ -232,7 +231,7 @@ { "cell_type": "code", "execution_count": null, - "id": "30dbecca", + "id": "eab0f4fa", "metadata": {}, "outputs": [], "source": [ @@ -257,7 +256,7 @@ }, { "cell_type": "markdown", - "id": "56d27ef1", + "id": "a00bf4e4", "metadata": { "lines_to_next_cell": 2 }, @@ -268,7 +267,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6da6f37a", + "id": "52834c9f", "metadata": {}, "outputs": [], "source": [ @@ -303,7 +302,7 @@ { "cell_type": "code", "execution_count": null, - "id": "50695462", + "id": "1fcef3dc", "metadata": {}, "outputs": [], "source": [ @@ -314,7 +313,7 @@ { "cell_type": "code", "execution_count": null, - "id": "393a458d", + "id": "25397c1e", "metadata": {}, "outputs": [], "source": [ @@ -325,7 +324,7 @@ { "cell_type": "code", "execution_count": null, - "id": "95ce5bf2", + "id": "a94067cc", "metadata": {}, "outputs": [], "source": [ @@ -340,7 +339,7 @@ }, { "cell_type": "markdown", - "id": "56ccbafa", + "id": "a5b1a9ee", "metadata": {}, "source": [ "# Fit GP to data: optimize hyper params" @@ -349,7 +348,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d44f3667", + "id": "c15e6d45", "metadata": {}, "outputs": [], "source": [ @@ -377,7 +376,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d1ff9c36", + "id": "0c7a4643", "metadata": {}, "outputs": [], "source": [ @@ -392,7 +391,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c6c9c1b2", + "id": "5efa3b9d", "metadata": {}, "outputs": [], "source": [ @@ -402,7 +401,7 @@ }, { "cell_type": "markdown", - "id": "b01dac44", + "id": "898f74f7", "metadata": {}, "source": [ "# Run prediction" @@ -411,7 +410,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5bb6926d", + "id": "59e18623", "metadata": { "lines_to_next_cell": 2 }, @@ -448,7 +447,7 @@ }, { "cell_type": "markdown", - "id": "9a354f60", + "id": "591e453d", "metadata": {}, "source": [ "# Plot difference to ground truth and uncertainty" @@ -457,7 +456,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f3570739", + "id": "a7f294c4", "metadata": {}, "outputs": [], "source": [ @@ -509,7 +508,7 @@ }, { "cell_type": "markdown", - "id": "9bfc86ee", + "id": "04257cea", "metadata": {}, "source": [ "# Check learned noise" @@ -518,7 +517,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6ad91b3b", + "id": "9d966f54", "metadata": {}, "outputs": [], "source": [ @@ -533,7 +532,7 @@ }, { "cell_type": "markdown", - "id": "18e5a5ee", + "id": "1da209ff", "metadata": {}, "source": [ "# Plot confidence bands" @@ -542,7 +541,7 @@ { "cell_type": "code", "execution_count": null, - "id": "46c9aab5", + "id": "7f56936d", "metadata": {}, "outputs": [], "source": [ @@ -570,7 +569,7 @@ { "cell_type": "code", "execution_count": null, - "id": "81788c9d", + "id": "d25b4506", "metadata": {}, "outputs": [], "source": [