Commit 83d58baa authored by Steve Schmerler

gp: 01_one_dim: update text

parent a6091eb1
1 merge request: !2 Update GP slides and notebooks
@@ -214,11 +214,12 @@ with torch.no_grad():
 # -
-# Let's investigate the samples more closely. A constant mean $\ve m(\ma X) =
-# \ve c$ does *not* mean that each sampled vector $\predve f$'s mean is
-# equal to $c$. Instead, we have that at each $\ve x_i$, the mean of
-# *all* sampled functions is the same, so $\frac{1}{M}\sum_{j=1}^M f_j(\ve x_i)
-# \approx c$ and for $M\rightarrow\infty$ it will be exactly $c$.
+# Let's investigate the samples more closely. First we note that the samples
+# fluctuate around the mean `model.mean_module.constant` we defined above. A
+# constant mean $\ve m(\ma X) = \ve c$ does *not* mean that each sampled vector
+# $\predve f$'s mean is equal to $c$. Instead, we have that at each $\ve x_i$,
+# the mean of *all* sampled functions is the same, so $\frac{1}{M}\sum_{j=1}^M
+# f_j(\ve x_i) \approx c$ and for $M\rightarrow\infty$ it will be exactly $c$.
 #
 # Look at the first 20 x points from M=10 samples
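Editor's note: the updated text above says that individual prior samples only fluctuate around the constant mean, while the per-point average over many samples approaches $c$. A minimal self-contained sketch of this with plain PyTorch (toy values for $c$, $s$, $\ell$; not the notebook's gpytorch model):

import torch

torch.manual_seed(0)

M, N = 500, 50              # number of prior samples, number of x points
c, s, ell = 3.0, 1.0, 0.5   # constant mean, output scale, length scale (assumed values)

X = torch.linspace(0.0, 1.0, N, dtype=torch.float64)
# RBF kernel k(x, x') = s * exp(-(x - x')^2 / (2 ell^2)), plus jitter for stability
K = s * torch.exp(-((X[:, None] - X[None, :]) ** 2) / (2.0 * ell**2))
K = K + 1.0e-6 * torch.eye(N, dtype=torch.float64)

prior = torch.distributions.MultivariateNormal(c * torch.ones(N, dtype=torch.float64), K)
f_samples = prior.sample(torch.Size([M]))   # shape (M, N)

# Each individual sample only fluctuates around c; the average over all
# samples at each x_i approaches c as M grows.
print(f_samples.mean(axis=0)[:5])   # per-point mean over samples: close to c
print(f_samples.mean(axis=1)[:5])   # per-sample means: generally not equal to c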
@@ -242,7 +243,7 @@ print(f"{f_samples.mean(axis=0).std()=}")
 #
 # We use the fixed hyper param values defined above. In particular, since
 # $\sigma_n^2$ = `model.likelihood.noise_covar.noise` is > 0, we have a
-# regression setting.
+# regression setting -- the GP's mean doesn't interpolate all points.
 # +
 # Evaluation (predictive posterior) mode
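Editor's note: the changed line above points out that $\sigma_n^2 > 0$ gives a regression setting in which the posterior mean does not interpolate the data. A small sketch of that fact using the closed-form posterior mean at the training points, $\ma K(\ma K + \sigma_n^2\,\ma I)^{-1}\ve y$, with assumed toy data:

import math
import torch

torch.manual_seed(0)

N, s, ell = 8, 1.0, 0.25
X = torch.linspace(0.0, 1.0, N)
y = torch.sin(2.0 * math.pi * X) + 0.1 * torch.randn(N)

# RBF kernel matrix of the training points
K = s * torch.exp(-((X[:, None] - X[None, :]) ** 2) / (2.0 * ell**2))

for sigma_n2 in (0.0, 0.1):
    # GP posterior mean at the training points: K (K + sigma_n^2 I)^{-1} y
    A = K + (sigma_n2 + 1.0e-8) * torch.eye(N)   # small jitter for stability
    mean_at_train = K @ torch.linalg.solve(A, y)
    print(f"sigma_n^2 = {sigma_n2}: max |post_mean - y| = "
          f"{(mean_at_train - y).abs().max().item():.4f}")

With $\sigma_n^2 = 0$ the error is only numerical noise (interpolation); with $\sigma_n^2 > 0$ the mean visibly misses the data points (regression).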
@@ -297,20 +298,23 @@ with torch.no_grad():
     ax.legend()
 # -
+# We observe that all sampled functions (green) and the mean (red) tend towards
+# the low fixed mean function $m(\ve x)=c$ at 3.0 in the absence of data, while
+# the actual data mean is `const` from above (data generation). Also the other
+# hyper params ($\ell$, $\sigma_n^2$, $s$) are just guesses. Now we will
+# calculate their actual value by minimizing the negative log marginal
+# likelihood.
 # # Fit GP to data: optimize hyper params
 #
 # In each step of the optimizer, we condition on the training data (e.g. do
 # Bayesian inference) to calculate the posterior predictive distribution for
-# the current values of the hyper params. We iterate until the log marginal
+# the current values of the hyper params. We iterate until the negative log marginal
 # likelihood is converged.
 #
 # We use a simplistic PyTorch-style hand written train loop without convergence
 # control, so make sure to use enough `n_iter` and eyeball-check that the loss
 # is converged :-)
-#
-# Observe how all hyper params converge. In particular, note that the constant
-# mean $m(\ve x)=c$ converges to the `const` value in `generate_data()`.
 # +
 # Train mode
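Editor's note: the hunk above describes the hand-written train loop that minimizes the negative log marginal likelihood. A self-contained sketch of that pattern with gpytorch; the model class, toy data, learning rate, and `n_iter` here are assumptions, not the notebook's actual definitions:

import torch
import gpytorch

# Assumed toy data; the notebook's real data comes from generate_data().
torch.manual_seed(0)
X_train = torch.linspace(0.0, 4.0, 30)
y_train = torch.sin(X_train) + 5.0 + 0.2 * torch.randn(X_train.shape)

class ExactGPModel(gpytorch.models.ExactGP):
    """Simple GP: constant mean plus scaled RBF kernel."""
    def __init__(self, X, y, likelihood):
        super().__init__(X, y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x), self.covar_module(x)
        )

likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = ExactGPModel(X_train, y_train, likelihood)

# Train mode: optimize hyper params (mean constant, length scale, output
# scale, noise) by minimizing the negative log marginal likelihood.
model.train()
likelihood.train()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

n_iter = 200
history = {"loss": []}
for ii in range(n_iter):
    optimizer.zero_grad()
    loss = -mll(model(X_train), y_train)
    loss.backward()
    optimizer.step()
    history["loss"].append(loss.item())

# The fitted constant mean ends up roughly at the average of y_train.
print(model.mean_module.constant.item(), y_train.mean().item())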
@@ -334,7 +338,7 @@ for ii in range(n_iter):
     history["loss"].append(loss.item())
 # -
-# Plot hyper params and loss (neg. log marginal likelihood) convergence
+# Plot hyper params and loss (negative log marginal likelihood) convergence
 ncols = len(history)
 fig, axs = plt.subplots(ncols=ncols, nrows=1, figsize=(ncols * 5, 5))
 for ax, (p_name, p_lst) in zip(axs, history.items()):
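Editor's note: the convergence plots above iterate over a `history` dict filled with per-iteration hyper param values. The notebook's `extract_model_params()` helper is not shown in this diff; a hypothetical stand-in, assuming the gpytorch model sketched above, could look like:

def get_hyper_params(model):
    # Transformed ("actual", non-raw) hyper param values as plain floats,
    # for a gpytorch ExactGP with ConstantMean and ScaleKernel(RBFKernel()).
    return {
        "mean_const": model.mean_module.constant.item(),
        "length_scale": model.covar_module.base_kernel.lengthscale.item(),
        "output_scale": model.covar_module.outputscale.item(),
        "noise": model.likelihood.noise_covar.noise.item(),
    }

# Inside the train loop, record one value per iteration:
#     for p_name, val in get_hyper_params(model).items():
#         history.setdefault(p_name, []).append(val)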
@@ -345,6 +349,9 @@ for ax, (p_name, p_lst) in zip(axs, history.items()):
 # Values of optimized hyper params
 pprint(extract_model_params(model, raw=False))
+# We see that all hyper params converge. In particular, note that the constant
+# mean $m(\ve x)=c$ converges to the `const` value in `generate_data()`.
 # # Run prediction
 #
 # We show "noiseless" (left: $\sigma = \sqrt{\mathrm{diag}(\ma\Sigma)}$) vs.
...
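Editor's note: the truncated hunk above introduces the "noiseless" vs. noisy prediction plots. Continuing the training sketch from above (assumed names, not the notebook's code): in gpytorch, `model(X_pred)` gives the posterior over the latent $\predve f$ with covariance $\ma\Sigma$, while applying the likelihood adds $\sigma_n^2$ to its diagonal.

# Evaluation (predictive posterior) mode, continuing the sketch above.
model.eval()
likelihood.eval()

X_pred = torch.linspace(-1.0, 5.0, 200)
with torch.no_grad():
    post_f = model(X_pred)          # "noiseless": variance = diag(Sigma)
    post_y = likelihood(post_f)     # noisy: variance = diag(Sigma) + sigma_n^2

sigma_n2 = model.likelihood.noise_covar.noise
# The two predictive variances differ by the learned noise term.
print((post_y.variance - (post_f.variance + sigma_n2)).abs().max())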