Skip to content
Snippets Groups Projects
Commit 744a7045 authored by Steve Schmerler's avatar Steve Schmerler
Browse files

01_one_dim: use new extract_model_params() defaults

parent c8d7c137
No related branches found
No related tags found
1 merge request!2Update GP slides and notebooks
......@@ -171,7 +171,7 @@ model = ExactGPModel(X_train, y_train, likelihood)
print(model)
# Default start hyper params
pprint(extract_model_params(model, raw=False))
pprint(extract_model_params(model))
# +
# Set new start hyper params
......@@ -180,7 +180,7 @@ model.covar_module.base_kernel.lengthscale = 1.0
model.covar_module.outputscale = 1.0
model.likelihood.noise_covar.noise = 1e-3
pprint(extract_model_params(model, raw=False))
pprint(extract_model_params(model))
# -
......@@ -347,7 +347,7 @@ for ii in range(n_iter):
optimizer.step()
if (ii + 1) % 10 == 0:
print(f"iter {ii + 1}/{n_iter}, {loss=:.3f}")
for p_name, p_val in extract_model_params(model).items():
for p_name, p_val in extract_model_params(model, try_item=True).items():
history[p_name].append(p_val)
history["loss"].append(loss.item())
# -
......@@ -361,7 +361,7 @@ for ax, (p_name, p_lst) in zip(axs, history.items()):
ax.set_xlabel("iterations")
# Values of optimized hyper params
pprint(extract_model_params(model, raw=False))
pprint(extract_model_params(model))
# We see that all hyper params converge. In particular, note that the constant
# mean $m(\ve x)=c$ converges to the `const` value in `generate_data()`.
......@@ -489,7 +489,7 @@ print(
print(
"learned noise:",
np.sqrt(
extract_model_params(model, raw=False)["likelihood.noise_covar.noise"]
extract_model_params(model, try_item=True)["likelihood.noise_covar.noise"]
),
)
# -
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment