From 49cfb4575621a5db40fd67aca13a38344924e477 Mon Sep 17 00:00:00 2001 From: Steve Schmerler <git@elcorto.com> Date: Thu, 15 May 2025 20:07:10 +0200 Subject: [PATCH] utils: make extract_model_params handle more tensors --- BLcourse2.3/utils.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/BLcourse2.3/utils.py b/BLcourse2.3/utils.py index fc512d4..c8a66c4 100644 --- a/BLcourse2.3/utils.py +++ b/BLcourse2.3/utils.py @@ -1,7 +1,8 @@ from matplotlib import pyplot as plt +import torch -def extract_model_params(model, raw=False) -> dict: +def extract_model_params(model, raw=False, try_item=True) -> dict: """Helper to convert model.named_parameters() to dict. With raw=True, use @@ -11,9 +12,24 @@ def extract_model_params(model, raw=False) -> dict: See https://docs.gpytorch.ai/en/stable/examples/00_Basic_Usage/Hyperparameters.html#Raw-vs-Actual-Parameters """ + if try_item: + def get_val(p): + if isinstance(p, torch.Tensor): + if p.ndim == 0: + return p.item() + else: + p_sq = p.squeeze() + if (p_sq.ndim == 1 and len(p_sq) == 1) or p_sq.ndim == 0: + return p_sq.item() + else: + return p + else: + return p + else: + get_val = lambda p: p if raw: return dict( - (p_name, p_val.item()) + (p_name, get_val(p_val)) for p_name, p_val in model.named_parameters() ) else: @@ -24,7 +40,7 @@ def extract_model_params(model, raw=False) -> dict: # Yes, eval() hack. Sorry. p_name = p_name.replace(".raw_", ".") p_val = eval(f"model.{p_name}") - out[p_name] = p_val.item() + out[p_name] = get_val(p_val) return out -- GitLab