diff --git a/BLcourse2.3/01_one_dim.py b/BLcourse2.3/01_one_dim.py index 9043f2aaa0378e9db4bfec4eca37f61e3d5d5dc7..66bafc5b89f8a4b4f30736d69ac18a0f153a31a6 100644 --- a/BLcourse2.3/01_one_dim.py +++ b/BLcourse2.3/01_one_dim.py @@ -354,11 +354,14 @@ for ii in range(n_iter): # Plot hyper params and loss (negative log marginal likelihood) convergence ncols = len(history) -fig, axs = plt.subplots(ncols=ncols, nrows=1, figsize=(ncols * 5, 5)) -for ax, (p_name, p_lst) in zip(axs, history.items()): - ax.plot(p_lst) - ax.set_title(p_name) - ax.set_xlabel("iterations") +fig, axs = plt.subplots( + ncols=ncols, nrows=1, figsize=(ncols * 3, 3), layout="compressed" +) +with torch.no_grad(): + for ax, (p_name, p_lst) in zip(axs, history.items()): + ax.plot(p_lst) + ax.set_title(p_name) + ax.set_xlabel("iterations") # Values of optimized hyper params pprint(extract_model_params(model)) @@ -387,7 +390,7 @@ with torch.no_grad(): post_pred_f = model(X_pred) post_pred_y = likelihood(model(X_pred)) - fig, axs = plt.subplots(ncols=2, figsize=(12, 5), sharex=True, sharey=True) + fig, axs = plt.subplots(ncols=2, figsize=(14, 5), sharex=True, sharey=True) fig_sigmas, ax_sigmas = plt.subplots() for ii, (ax, post_pred, name, title) in enumerate( zip(