diff --git a/BLcourse2.3/gp_intro.py b/BLcourse2.3/gp_intro.py
index 4b290e407823f62e0a17967fa78dcb477b780bb4..6503d40b6032bc03818ba51236d7a82c2eac4e7b 100644
--- a/BLcourse2.3/gp_intro.py
+++ b/BLcourse2.3/gp_intro.py
@@ -324,7 +324,10 @@ with torch.no_grad():
     post_pred_y = likelihood(model(X_pred))
 
     fig, axs = plt.subplots(ncols=2, figsize=(12, 5))
-    for ii, (ax, post_pred) in enumerate(zip(axs, [post_pred_f, post_pred_y])):
+    fig_sigmas, ax_sigmas = plt.subplots()
+    for ii, (ax, post_pred, name) in enumerate(
+        zip(axs, [post_pred_f, post_pred_y], ["f", "y"])
+    ):
         yf_mean = post_pred.mean
         yf_samples = post_pred.sample(sample_shape=torch.Size((10,)))
 
@@ -354,6 +357,23 @@ with torch.no_grad():
             color="tab:orange",
             alpha=0.3,
         )
+        if name == "f":
+            sigma_label = r"$\pm 2\sqrt{\mathrm{diag}(\Sigma)}$"
+            zorder = 1
+        else:
+            sigma_label = (
+                r"$\pm 2\sqrt{\mathrm{diag}(\Sigma + \sigma_n^2\,I)}$"
+            )
+            zorder = 0
+        ax_sigmas.fill_between(
+            X_pred.numpy(),
+            lower.numpy(),
+            upper.numpy(),
+            label="confidence " + sigma_label,
+            color="tab:orange" if name == "f" else "tab:blue",
+            alpha=0.5,
+            zorder=zorder,
+        )
         y_min = y_train.min()
         y_max = y_train.max()
         y_span = y_max - y_min
@@ -361,6 +381,7 @@ with torch.no_grad():
         plot_samples(ax, X_pred, yf_samples, label="posterior pred. samples")
         if ii == 1:
             ax.legend()
+            ax_sigmas.legend()
 
 # When running as script
 if not is_interactive():