Skip to content
Snippets Groups Projects
Commit 49cfb457 authored by Steve Schmerler's avatar Steve Schmerler
Browse files

utils: make extract_model_params handle more tensors

parent 279916d3
No related branches found
No related tags found
1 merge request!2Update GP slides and notebooks
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
import torch
def extract_model_params(model, raw=False) -> dict: def extract_model_params(model, raw=False, try_item=True) -> dict:
"""Helper to convert model.named_parameters() to dict. """Helper to convert model.named_parameters() to dict.
With raw=True, use With raw=True, use
...@@ -11,9 +12,24 @@ def extract_model_params(model, raw=False) -> dict: ...@@ -11,9 +12,24 @@ def extract_model_params(model, raw=False) -> dict:
See https://docs.gpytorch.ai/en/stable/examples/00_Basic_Usage/Hyperparameters.html#Raw-vs-Actual-Parameters See https://docs.gpytorch.ai/en/stable/examples/00_Basic_Usage/Hyperparameters.html#Raw-vs-Actual-Parameters
""" """
if try_item:
def get_val(p):
if isinstance(p, torch.Tensor):
if p.ndim == 0:
return p.item()
else:
p_sq = p.squeeze()
if (p_sq.ndim == 1 and len(p_sq) == 1) or p_sq.ndim == 0:
return p_sq.item()
else:
return p
else:
return p
else:
get_val = lambda p: p
if raw: if raw:
return dict( return dict(
(p_name, p_val.item()) (p_name, get_val(p_val))
for p_name, p_val in model.named_parameters() for p_name, p_val in model.named_parameters()
) )
else: else:
...@@ -24,7 +40,7 @@ def extract_model_params(model, raw=False) -> dict: ...@@ -24,7 +40,7 @@ def extract_model_params(model, raw=False) -> dict:
# Yes, eval() hack. Sorry. # Yes, eval() hack. Sorry.
p_name = p_name.replace(".raw_", ".") p_name = p_name.replace(".raw_", ".")
p_val = eval(f"model.{p_name}") p_val = eval(f"model.{p_name}")
out[p_name] = p_val.item() out[p_name] = get_val(p_val)
return out return out
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment