Commit ac660f5b authored by Steve Schmerler

Better VI theory part in 03_one_dim_SVI

parent 6e4ff1ee
Branches feature-gp
1 merge request: !3 Update 03_one_dim_SVI
%% Cell type:markdown id:b146bc52 tags:
# About
In this notebook, we replace ExactGP inference and log marginal likelihood
optimization with sparse stochastic variational inference. This serves as an
example of the many methods `gpytorch` offers to make GPs scale to large data
sets.
$\newcommand{\ve}[1]{\mathit{\boldsymbol{#1}}}$
$\newcommand{\ma}[1]{\mathbf{#1}}$
$\newcommand{\pred}[1]{\rm{#1}}$
$\newcommand{\predve}[1]{\mathbf{#1}}$
$\newcommand{\test}[1]{#1_*}$
$\newcommand{\testtest}[1]{#1_{**}}$
$\newcommand{\dd}{{\rm{d}}}$
$\newcommand{\lt}[1]{_{\text{#1}}}$
$\DeclareMathOperator{\diag}{diag}$
$\DeclareMathOperator{\cov}{cov}$
%% Cell type:markdown id:f96e3304 tags:
# Imports, helpers, setup
%% Cell type:code id:193b2f13 tags:
``` python
##%matplotlib notebook
%matplotlib widget
##%matplotlib inline
```
%% Cell type:code id:c11c1030 tags:
``` python
import math
from collections import defaultdict
from pprint import pprint

import torch
import gpytorch
from matplotlib import pyplot as plt
from matplotlib import is_interactive
import numpy as np
from torch.utils.data import TensorDataset, DataLoader

from utils import extract_model_params, plot_samples

torch.set_default_dtype(torch.float64)
torch.manual_seed(123)
```
%% Cell type:markdown id:460d4c3d tags:
# Generate toy 1D data
Now we generate 10x more points than in the ExactGP case, yet inference
won't be much slower (exact GPs scale roughly as $N^3$). Note that the data we
use here is still tiny (1000 points is easy even for exact GPs), so the
method's usefulness cannot be fully demonstrated in our example
-- also we don't even use a GPU yet :).
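As a rough orientation only (standard textbook scalings, not a statement about
this particular `gpytorch` implementation): exact GP inference costs about
$\mathcal{O}(N^3)$ time and $\mathcal{O}(N^2)$ memory in the number of
training points $N$, while the sparse variational method used below, with
$M \ll N$ inducing points and mini-batches of size $B$, costs roughly

$$\mathcal{O}(B\,M^2 + M^3)\quad\text{per optimizer step,}$$

which is why it pays off for large $N$.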
%% Cell type:code id:f1a9647d tags:
``` python
def ground_truth(x, const):
    return torch.sin(x) * torch.exp(-0.2 * x) + const


def generate_data(x, gaps=[[1, 3]], const=None, noise_std=None):
    noise_dist = torch.distributions.Normal(loc=0, scale=noise_std)
    y = ground_truth(x, const=const) + noise_dist.sample(
        sample_shape=(len(x),)
    )
    msk = torch.tensor([True] * len(x))
    if gaps is not None:
        for g in gaps:
            msk = msk & ~((x > g[0]) & (x < g[1]))
    return x[msk], y[msk], y


const = 5.0
noise_std = 0.1
x = torch.linspace(0, 4 * math.pi, 1000)
X_train, y_train, y_gt_train = generate_data(
    x, gaps=[[6, 10]], const=const, noise_std=noise_std
)
X_pred = torch.linspace(
    X_train[0] - 2, X_train[-1] + 2, 200, requires_grad=False
)
y_gt_pred = ground_truth(X_pred, const=const)

print(f"{X_train.shape=}")
print(f"{y_train.shape=}")
print(f"{X_pred.shape=}")

fig, ax = plt.subplots()
ax.scatter(X_train, y_train, marker="o", color="tab:blue", label="noisy data")
ax.plot(X_pred, y_gt_pred, ls="--", color="k", label="ground truth")
ax.legend()
```
%% Cell type:markdown id:4ad60c5e tags:
# Define GP model
The model follows [this
example](https://docs.gpytorch.ai/en/stable/examples/04_Variational_and_Approximate_GPs/SVGP_Regression_CUDA.html)
based on [Hensman et al., "Scalable Variational Gaussian Process Classification",
2015](https://proceedings.mlr.press/v38/hensman15.html). The model is
"sparse" since it works with a set of *inducing* points $(\ma Z, \ve u),
\ve u=f(\ma Z)$, which is much smaller than the training data $(\ma X, \ve y)$.
See also [the GPJax
docs](https://docs.jaxgaussianprocesses.com/_examples/uncollapsed_vi) for a
nice introduction.

We have the same hyper parameters as before

* $\ell$ = `model.covar_module.base_kernel.lengthscale`
* $\sigma_n^2$ = `likelihood.noise_covar.noise`
* $s$ = `model.covar_module.outputscale`
* $m(\ve x) = c$ = `model.mean_module.constant`

plus additional ones, introduced by the approximations used:

* the learnable inducing points $\ma Z$ for the variational distribution
  $q_{\ve\psi}(\ve u)$
* the learnable parameters of the variational distribution
  $q_{\ve\psi}(\ve u)=\mathcal N(\ve m_u, \ma S)$: the variational mean
  $\ve m_u$ and the covariance $\ma S$, parametrized by a lower triangular
  matrix $\ma L$ such that $\ma S=\ma L\,\ma L^\top$ (a rough parameter count
  follows below)
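As a quick back-of-the-envelope count (an illustration assuming $M$ inducing
points and 1D inputs as in this notebook, not a statement about how `gpytorch`
stores these parameters internally), the variational parameters

$$\ve\psi = [\ve m_u, \ma Z, \ma L]$$

amount to $M + M + M(M+1)/2$ numbers, i.e. their count scales with $M$ and not
with the number of training points $N$.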
%% Cell type:code id:444929a3 tags:
``` python
class ApproxGPModel(gpytorch.models.ApproximateGP):
    def __init__(self, Z):
        # Approximate inducing value posterior q(u), u = f(Z), Z = inducing
        # points (subset of X_train)
        variational_distribution = (
            gpytorch.variational.CholeskyVariationalDistribution(Z.size(0))
        )
        # Compute q(f(X)) from q(u)
        variational_strategy = gpytorch.variational.VariationalStrategy(
            self,
            Z,
            variational_distribution,
            learn_inducing_locations=True,
        )
        super().__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel()
        )

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


likelihood = gpytorch.likelihoods.GaussianLikelihood()

n_train = len(X_train)

# Start values for inducing points Z, use 10% random sub-sample of X_train.
ind_points_fraction = 0.1
ind_idxs = torch.randperm(n_train)[: int(n_train * ind_points_fraction)]
model = ApproxGPModel(Z=X_train[ind_idxs])
```
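To see which variational parameters the model actually registers, we can list
the parameters owned by the variational strategy. This is only a quick
inspection sketch: the parameter names and exact shapes mentioned in the
comments are assumptions and may differ between `gpytorch` versions.

``` python
# List variational parameters (name, shape). With roughly 10% of 1000 points
# used as inducing points, we expect inducing points of length ~100, a
# variational mean of the same length, and a (~100, ~100) Cholesky factor.
for p_name, p in model.variational_strategy.named_parameters():
    print(p_name, tuple(p.shape))
```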
%% Cell type:code id:5861163d tags:
``` python
# Inspect the model
print(model)
```
%% Cell type:code id:98032457 tags:
``` python
# Inspect the likelihood. In contrast to ExactGP, the likelihood is not part of
# the GP model instance.
print(likelihood)
```
%% Cell type:code id:7ec97687 tags:
``` python
# Default start hyper params
print("model params:")
pprint(extract_model_params(model))
print("likelihood params:")
pprint(extract_model_params(likelihood))
```
%% Cell type:code id:167d584c tags:
``` python
# Set new start hyper params (scalars only)
model.mean_module.constant = 3.0
model.covar_module.base_kernel.lengthscale = 1.0
model.covar_module.outputscale = 1.0
likelihood.noise_covar.noise = 0.3
```
%% Cell type:markdown id:6704ce2f tags:
# Fit GP to data: optimize hyper params

Now we optimize the GP hyper parameters by doing a GP-specific variational
inference (VI), where we don't maximize the log marginal likelihood (ExactGP
case), but an ELBO ("evidence lower bound") objective -- a lower bound on the
marginal likelihood (the "evidence"). In variational inference, an ELBO
objective shows up when minimizing the KL divergence between an approximate
and the true posterior

$$
p(w|y) = \frac{p(y|w)\,p(w)}{\int p(y|w)\,p(w)\,\dd w}
       = \frac{p(y|w)\,p(w)}{p(y)}
$$

$$
\ve\zeta^* = \text{arg}\min_{\ve\zeta} D\lt{KL}(q_{\ve\zeta}(w)\,\Vert\, p(w|y))
$$

to obtain the optimal variational parameters $\ve\zeta^*$ that approximate
$p(w|y)$ with $q_{\ve\zeta^*}(w)$.

In our case the two distributions are the approximate posterior

$$q_{\ve\zeta}(\mathbf f)=\int p(\mathbf f|\ve u)\,q_{\ve\psi}(\ve u)\,\dd\ve u\quad(\text{"variational strategy"})$$

and the true posterior $p(\mathbf f|\mathcal D)$ over function values. We
optimize with respect to

$$\ve\zeta = [\ell, \sigma_n^2, s, c, \ve\psi]$$

with

$$\ve\psi = [\ve m_u, \ma Z, \ma L]$$

the parameters of the variational distribution $q_{\ve\psi}(\ve u)$.

In addition, we perform a stochastic optimization by using a deep learning
style mini-batch loop, hence "stochastic" variational inference (SVI). This
speeds up the optimization since we only look at a fraction of the data per
optimizer step to calculate an approximate loss gradient (`loss.backward()`).
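For orientation, the resulting objective can be written as follows (a sketch
in the spirit of Hensman et al. 2015; `gpytorch`'s `VariationalELBO` implements
this kind of objective, up to details such as scaling conventions):

$$\log p(\ve y) \;\ge\; \sum_{i=1}^N \mathbb{E}_{q_{\ve\zeta}(f_i)}\left[\log p(y_i|f_i)\right] \;-\; D\lt{KL}\left(q_{\ve\psi}(\ve u)\,\Vert\, p(\ve u)\right)$$

Because the likelihood term is a sum over the $N$ data points, it can be
estimated without bias from a mini-batch by rescaling the batch sum with
$N/B$ (batch size $B$); this is, up to the exact scaling convention, what the
`num_data` argument of `gpytorch.mlls.VariationalELBO` used below is for.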
%% Cell type:code id:f39d86f8 tags:
``` python
# Train mode
model.train()
likelihood.train()

optimizer = torch.optim.Adam(
    [dict(params=model.parameters()), dict(params=likelihood.parameters())],
    lr=0.1,
)

loss_func = gpytorch.mlls.VariationalELBO(
    likelihood, model, num_data=X_train.shape[0]
)

train_dl = DataLoader(
    TensorDataset(X_train, y_train), batch_size=128, shuffle=True
)

n_iter = 200
history = defaultdict(list)
for i_iter in range(n_iter):
    for i_batch, (X_batch, y_batch) in enumerate(train_dl):
        batch_history = defaultdict(list)
        optimizer.zero_grad()
        loss = -loss_func(model(X_batch), y_batch)
        loss.backward()
        optimizer.step()
        param_dct = dict()
        param_dct.update(extract_model_params(model, try_item=True))
        param_dct.update(extract_model_params(likelihood, try_item=True))
        for p_name, p_val in param_dct.items():
            if isinstance(p_val, float):
                batch_history[p_name].append(p_val)
        batch_history["loss"].append(loss.item())
    for p_name, p_lst in batch_history.items():
        history[p_name].append(np.mean(p_lst))
    if (i_iter + 1) % 10 == 0:
        print(f"iter {i_iter + 1}/{n_iter}, {loss=:.3f}")
```
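As an optional sanity check (a small sketch, not part of the original notebook
flow), we can evaluate the ELBO once on the full training set after the loop;
its negative should be close to the last mini-batch loss values printed above.

``` python
# Full-dataset loss (negative ELBO) after training. The model is still in
# train mode here; we only switch off gradient tracking.
with torch.no_grad():
    full_loss = -loss_func(model(X_train), y_train)
print(f"full-data loss (neg. ELBO) = {full_loss.item():.3f}")
```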
%% Cell type:code id:dc39b5d9 tags:
``` python
# Plot scalar hyper params and loss (ELBO) convergence
ncols = len(history)
fig, axs = plt.subplots(
    ncols=ncols, nrows=1, figsize=(ncols * 3, 3), layout="compressed"
)
with torch.no_grad():
    for ax, (p_name, p_lst) in zip(axs, history.items()):
        ax.plot(p_lst)
        ax.set_title(p_name)
        ax.set_xlabel("iterations")
```
%% Cell type:code id:c4907e39 tags:
``` python
# Values of optimized hyper params
print("model params:")
pprint(extract_model_params(model))
print("likelihood params:")
pprint(extract_model_params(likelihood))
```
%% Cell type:markdown id:4ae35e11 tags:
# Run prediction
%% Cell type:code id:629b4658 tags:
``` python
# Evaluation (predictive posterior) mode
model.eval()
likelihood.eval()

with torch.no_grad():
    M = 10
    post_pred_f = model(X_pred)
    post_pred_y = likelihood(model(X_pred))

    fig, axs = plt.subplots(ncols=2, figsize=(14, 5), sharex=True, sharey=True)
    fig_sigmas, ax_sigmas = plt.subplots()
    for ii, (ax, post_pred, name, title) in enumerate(
        zip(
            axs,
            [post_pred_f, post_pred_y],
            ["f", "y"],
            ["epistemic uncertainty", "total uncertainty"],
        )
    ):
        yf_mean = post_pred.mean
        yf_samples = post_pred.sample(sample_shape=torch.Size((M,)))

        yf_std = post_pred.stddev
        lower = yf_mean - 2 * yf_std
        upper = yf_mean + 2 * yf_std
        ax.plot(
            X_train.numpy(),
            y_train.numpy(),
            "o",
            label="data",
            color="tab:blue",
        )
        ax.plot(
            X_pred.numpy(),
            yf_mean.numpy(),
            label="mean",
            color="tab:red",
            lw=2,
        )
        ax.plot(
            X_pred.numpy(),
            y_gt_pred.numpy(),
            label="ground truth",
            color="k",
            lw=2,
            ls="--",
        )
        ax.fill_between(
            X_pred.numpy(),
            lower.numpy(),
            upper.numpy(),
            label="confidence",
            color="tab:orange",
            alpha=0.3,
        )
        ax.set_title(f"confidence = {title}")
        if name == "f":
            sigma_label = r"epistemic: $\pm 2\sqrt{\mathrm{diag}(\Sigma_*)}$"
            zorder = 1
        else:
            sigma_label = (
                r"total: $\pm 2\sqrt{\mathrm{diag}(\Sigma_* + \sigma_n^2\,I)}$"
            )
            zorder = 0
        ax_sigmas.fill_between(
            X_pred.numpy(),
            lower.numpy(),
            upper.numpy(),
            label=sigma_label,
            color="tab:orange" if name == "f" else "tab:blue",
            alpha=0.5,
            zorder=zorder,
        )
        y_min = y_train.min()
        y_max = y_train.max()
        y_span = y_max - y_min
        ax.set_ylim([y_min - 0.3 * y_span, y_max + 0.3 * y_span])
        plot_samples(ax, X_pred, yf_samples, label="posterior pred. samples")
        if ii == 1:
            ax.legend()

    ax_sigmas.set_title("total vs. epistemic uncertainty")
    ax_sigmas.legend()
```
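As a small, optional check (an illustration only, not part of the original
notebook), we can compare the predictive mean against the noise-free ground
truth on the prediction grid. Keep in mind that `X_pred` extends beyond the
training data, so the error is dominated by the gap and the extrapolation
regions.

``` python
# RMSE of the predictive mean w.r.t. the noise-free ground truth on the
# prediction grid (includes the gap and extrapolation regions).
with torch.no_grad():
    rmse = torch.sqrt(torch.mean((post_pred_f.mean - y_gt_pred) ** 2))
print(f"RMSE(predictive mean, ground truth) = {rmse.item():.4f}")
```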
%% Cell type:markdown id:60abb7ec tags:
# Let's check the learned noise
%% Cell type:code id:d8a32f5b tags:
``` python
# Target noise to learn
print("data noise:", noise_std)

# The two below must be the same
print(
    "learned noise:",
    (post_pred_y.stddev**2 - post_pred_f.stddev**2).mean().sqrt().item(),
)
print(
    "learned noise:",
    np.sqrt(
        extract_model_params(likelihood, try_item=True)["noise_covar.noise"]
    ),
)
```
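Along the same lines, and again only as an illustrative check (how close the
value ends up is data dependent and an assumption here), the learned constant
mean should land roughly near the offset `const = 5.0` used to generate the
data, since the ground truth oscillates around it.

``` python
# Compare the learned constant mean against the data offset. Expect a value
# roughly near const = 5.0, not an exact match.
print("data offset:", const)
print("learned mean constant:", model.mean_module.constant.item())
```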
%% Cell type:code id:0637729a tags:
``` python
# When running as script
if not is_interactive():
    plt.show()
```