diff --git a/BLcourse2.3/utils.py b/BLcourse2.3/utils.py index fc512d45c86aeb380087bd297aed0b3b29d84d7b..c8a66c4cf4dd33fa5cf62d93e84287eabf33a405 100644 --- a/BLcourse2.3/utils.py +++ b/BLcourse2.3/utils.py @@ -1,7 +1,8 @@ from matplotlib import pyplot as plt +import torch -def extract_model_params(model, raw=False) -> dict: +def extract_model_params(model, raw=False, try_item=True) -> dict: """Helper to convert model.named_parameters() to dict. With raw=True, use @@ -11,9 +12,24 @@ def extract_model_params(model, raw=False) -> dict: See https://docs.gpytorch.ai/en/stable/examples/00_Basic_Usage/Hyperparameters.html#Raw-vs-Actual-Parameters """ + if try_item: + def get_val(p): + if isinstance(p, torch.Tensor): + if p.ndim == 0: + return p.item() + else: + p_sq = p.squeeze() + if (p_sq.ndim == 1 and len(p_sq) == 1) or p_sq.ndim == 0: + return p_sq.item() + else: + return p + else: + return p + else: + get_val = lambda p: p if raw: return dict( - (p_name, p_val.item()) + (p_name, get_val(p_val)) for p_name, p_val in model.named_parameters() ) else: @@ -24,7 +40,7 @@ def extract_model_params(model, raw=False) -> dict: # Yes, eval() hack. Sorry. p_name = p_name.replace(".raw_", ".") p_val = eval(f"model.{p_name}") - out[p_name] = p_val.item() + out[p_name] = get_val(p_val) return out