import operator

from matplotlib import pyplot as plt
import torch


def _maybe_item(p):
    """Return ``p.item()`` if *p* is a tensor holding exactly one element, else *p*.

    A tensor with a single element (shape ``()``, ``(1,)``, ``(1, 1)``, ...)
    is converted to a Python scalar; anything else is returned unchanged.
    """
    if isinstance(p, torch.Tensor) and p.numel() == 1:
        return p.item()
    return p


def _strip_raw_prefix(name):
    """Drop a leading ``raw_`` from the last dotted component of *name*.

    Only the final attribute name is touched, so intermediate module names
    that happen to contain ``raw_`` are preserved, e.g.
    ``"covar_module.base_kernel.raw_lengthscale"`` ->
    ``"covar_module.base_kernel.lengthscale"`` and ``"raw_noise"`` -> ``"noise"``.
    """
    head, sep, leaf = name.rpartition(".")
    prefix = "raw_"
    if leaf.startswith(prefix):
        leaf = leaf[len(prefix):]
    return head + sep + leaf


def extract_model_params(model, raw=False, try_item=False) -> dict:
    """Convert ``model.named_parameters()`` to a dict.

    Parameters
    ----------
    model
        Object exposing ``named_parameters()`` (e.g. a torch / GPyTorch
        module).
    raw
        If True, return the parameters exactly as yielded by
        ``named_parameters()``, keyed by their original names (e.g.
        ``covar_module.base_kernel.raw_lengthscale``).
        If False, strip the ``raw_`` prefix from the last name component
        and return ``model.<stripped name>`` instead (e.g.
        ``model.covar_module.base_kernel.lengthscale``), which in GPyTorch
        is the constrained ("actual") value.
    try_item
        If True, tensors holding exactly one element are converted to
        Python scalars via ``.item()``; all other values are returned
        unchanged.

    Returns
    -------
    dict
        Mapping of parameter name to value.

    Raises
    ------
    AttributeError
        If ``raw=False`` and the model has no attribute matching a
        stripped parameter name.

    See
    https://docs.gpytorch.ai/en/stable/examples/00_Basic_Usage/Hyperparameters.html#Raw-vs-Actual-Parameters
    """
    get_val = _maybe_item if try_item else (lambda p: p)

    if raw:
        return {p_name: get_val(p_val) for p_name, p_val in model.named_parameters()}

    out = {}
    for p_name, _ in model.named_parameters():
        name = _strip_raw_prefix(p_name)
        # Resolve the dotted path (e.g. "covar_module.base_kernel.lengthscale")
        # via attribute access rather than eval(); this triggers the same
        # property getters that return the constrained value.
        out[name] = get_val(operator.attrgetter(name)(model))
    return out


def fig_ax_3d():
    """Create a new figure with a single 3D axes and return ``(fig, ax)``."""
    fig = plt.figure()
    # projection="3d" selects matplotlib's mplot3d Axes3D.
    ax = fig.add_subplot(projection="3d")
    return fig, ax


def plot_samples(ax, X_pred, samples, label=None, **kwds):
    """Plot each row of *samples* as a line against *X_pred* on *ax*.

    Parameters
    ----------
    ax
        Matplotlib axes to draw on.
    X_pred
        x coordinates shared by all sample curves (presumably 1-D, with
        length equal to ``samples.shape[1]`` -- TODO confirm with callers).
    samples
        2-D array-like indexed as ``samples[i, :]``; each row is one curve.
    label
        If given, only the first curve gets this legend label, so the
        legend shows a single entry for the whole bundle of samples.
    **kwds
        Extra ``ax.plot`` keyword arguments; they override the defaults
        ``color="tab:green", alpha=0.3``.
    """
    plot_kwds = dict(color="tab:green", alpha=0.3)
    plot_kwds.update(kwds)
    if label is None:
        # Transpose so each sample (row) becomes one plotted column.
        ax.plot(X_pred, samples.T, **plot_kwds)
    else:
        ax.plot(X_pred, samples[0, :], **plot_kwds, label=label)
        # Labels starting with "_" are excluded from the legend, so the
        # remaining curves do not add duplicate entries.
        ax.plot(X_pred, samples[1:, :].T, **plot_kwds, label="_")