Skip to content
Snippets Groups Projects
Commit 53c4b256 authored by Steve Schmerler's avatar Steve Schmerler
Browse files

Sync extract_model_params() usage

parent 933f4591
Branches
No related tags found
1 merge request: !2 "Update GP slides and notebooks"
......@@ -500,7 +500,9 @@ print(
print(
"learned noise:",
np.sqrt(
extract_model_params(model, try_item=True)["likelihood.noise_covar.noise"]
extract_model_params(model, try_item=True)[
"likelihood.noise_covar.noise"
]
),
)
# -
......
......@@ -28,7 +28,6 @@
# $\newcommand{\testtest}[1]{#1_{**}}$
# $\DeclareMathOperator{\diag}{diag}$
# $\DeclareMathOperator{\cov}{cov}$
#
# +
# ##%matplotlib notebook
......@@ -239,7 +238,7 @@ model = ExactGPModel(X_train, y_train, likelihood)
print(model)
# Default start hyper params
pprint(extract_model_params(model, raw=False))
pprint(extract_model_params(model))
# +
# Set new start hyper params
......@@ -248,7 +247,7 @@ model.covar_module.base_kernel.lengthscale = 3.0
model.covar_module.outputscale = 8.0
model.likelihood.noise_covar.noise = 0.1
pprint(extract_model_params(model, raw=False))
pprint(extract_model_params(model))
# -
# # Fit GP to data: optimize hyper params
......@@ -270,7 +269,7 @@ for ii in range(n_iter):
optimizer.step()
if (ii + 1) % 10 == 0:
print(f"iter {ii + 1}/{n_iter}, {loss=:.3f}")
for p_name, p_val in extract_model_params(model).items():
for p_name, p_val in extract_model_params(model, try_item=True).items():
history[p_name].append(p_val)
history["loss"].append(loss.item())
# -
......@@ -283,7 +282,7 @@ for ax, (p_name, p_lst) in zip(axs, history.items()):
ax.set_xlabel("iterations")
# Values of optimized hyper params
pprint(extract_model_params(model, raw=False))
pprint(extract_model_params(model))
# # Run prediction
......@@ -403,7 +402,9 @@ print(
print(
"learned noise:",
np.sqrt(
extract_model_params(model, raw=False)["likelihood.noise_covar.noise"]
extract_model_params(model, try_item=True)[
"likelihood.noise_covar.noise"
]
),
)
# -
......
......@@ -171,9 +171,9 @@ print(likelihood)
# Default start hyper params
print("model params:")
pprint(extract_model_params(model, raw=False, try_item=False))
pprint(extract_model_params(model))
print("likelihood params:")
pprint(extract_model_params(likelihood, raw=False, try_item=False))
pprint(extract_model_params(likelihood))
# +
# Set new start hyper params
......@@ -258,9 +258,9 @@ with torch.no_grad():
# Values of optimized hyper params
print("model params:")
pprint(extract_model_params(model, raw=False, try_item=False))
pprint(extract_model_params(model))
print("likelihood params:")
pprint(extract_model_params(likelihood, raw=False, try_item=False))
pprint(extract_model_params(likelihood))
# # Run prediction
......@@ -364,7 +364,7 @@ print(
print(
"learned noise:",
np.sqrt(
extract_model_params(likelihood, raw=False)["noise_covar.noise"]
extract_model_params(likelihood, try_item=True)["noise_covar.noise"]
),
)
# -
......
......@@ -13,6 +13,7 @@ def extract_model_params(model, raw=False, try_item=False) -> dict:
See https://docs.gpytorch.ai/en/stable/examples/00_Basic_Usage/Hyperparameters.html#Raw-vs-Actual-Parameters
"""
if try_item:
def get_val(p):
if isinstance(p, torch.Tensor):
if p.ndim == 0:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment