From f4542a4664452f8d18b2753b52160ef605a5b496 Mon Sep 17 00:00:00 2001 From: Steve Schmerler <git@elcorto.com> Date: Sat, 17 May 2025 10:10:38 +0200 Subject: [PATCH] 01_one_dim: better plotting layout --- BLcourse2.3/01_one_dim.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/BLcourse2.3/01_one_dim.py b/BLcourse2.3/01_one_dim.py index 9043f2a..66bafc5 100644 --- a/BLcourse2.3/01_one_dim.py +++ b/BLcourse2.3/01_one_dim.py @@ -354,11 +354,14 @@ for ii in range(n_iter): # Plot hyper params and loss (negative log marginal likelihood) convergence ncols = len(history) -fig, axs = plt.subplots(ncols=ncols, nrows=1, figsize=(ncols * 5, 5)) -for ax, (p_name, p_lst) in zip(axs, history.items()): - ax.plot(p_lst) - ax.set_title(p_name) - ax.set_xlabel("iterations") +fig, axs = plt.subplots( + ncols=ncols, nrows=1, figsize=(ncols * 3, 3), layout="compressed" +) +with torch.no_grad(): + for ax, (p_name, p_lst) in zip(axs, history.items()): + ax.plot(p_lst) + ax.set_title(p_name) + ax.set_xlabel("iterations") # Values of optimized hyper params pprint(extract_model_params(model)) @@ -387,7 +390,7 @@ with torch.no_grad(): post_pred_f = model(X_pred) post_pred_y = likelihood(model(X_pred)) - fig, axs = plt.subplots(ncols=2, figsize=(12, 5), sharex=True, sharey=True) + fig, axs = plt.subplots(ncols=2, figsize=(14, 5), sharex=True, sharey=True) fig_sigmas, ax_sigmas = plt.subplots() for ii, (ax, post_pred, name, title) in enumerate( zip( -- GitLab