From 2e1e002d3c35df97ec30df2244f2fae271eec29d Mon Sep 17 00:00:00 2001 From: Steve Schmerler <git@elcorto.com> Date: Thu, 8 May 2025 12:28:26 +0200 Subject: [PATCH] gp: 01_one_dim: plot ground truth --- BLcourse2.3/01_one_dim.py | 33 +++++++++++++++++++++++++++++---- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/BLcourse2.3/01_one_dim.py b/BLcourse2.3/01_one_dim.py index f9f8f95..1b175af 100644 --- a/BLcourse2.3/01_one_dim.py +++ b/BLcourse2.3/01_one_dim.py @@ -77,26 +77,35 @@ torch.manual_seed(123) # + +def ground_truth(x, const): + return torch.sin(x) * torch.exp(-0.2 * x) + const + + def generate_data(x, gaps=[[1, 3]], const=5): - y = torch.sin(x) * torch.exp(-0.2 * x) + torch.randn(x.shape) * 0.1 + const + y = ground_truth(x, const=const) + torch.randn(x.shape) * 0.1 msk = torch.tensor([True] * len(x)) if gaps is not None: for g in gaps: msk = msk & ~((x > g[0]) & (x < g[1])) - return x[msk], y[msk] + return x[msk], y[msk], y +const = 5.0 x = torch.linspace(0, 4 * math.pi, 100) -X_train, y_train = generate_data(x, gaps=[[6, 10]]) +X_train, y_train, y_gt_train = generate_data(x, gaps=[[6, 10]], const=const) X_pred = torch.linspace( X_train[0] - 2, X_train[-1] + 2, 200, requires_grad=False ) +y_gt_pred = ground_truth(X_pred, const=const) print(f"{X_train.shape=}") print(f"{y_train.shape=}") print(f"{X_pred.shape=}") -plt.scatter(X_train, y_train, marker="o", color="tab:blue") +fig, ax = plt.subplots() +ax.scatter(X_train, y_train, marker="o", color="tab:blue", label="noisy data") +ax.plot(X_pred, y_gt_pred, ls="--", color="k", label="ground truth") +ax.legend() # - # # Define GP model @@ -264,6 +273,14 @@ with torch.no_grad(): color="tab:red", lw=2, ) + ax.plot( + X_pred.numpy(), + y_gt_pred.numpy(), + label="ground truth", + color="k", + lw=2, + ls="--", + ) ax.fill_between( X_pred.numpy(), lower.numpy(), @@ -377,6 +394,14 @@ with torch.no_grad(): color="tab:red", lw=2, ) + ax.plot( + X_pred.numpy(), + y_gt_pred.numpy(), + label="ground truth", + color="k", + lw=2, + ls="--", + ) ax.fill_between( X_pred.numpy(), lower.numpy(), -- GitLab