Skip to content
Snippets Groups Projects
Commit 2e1e002d authored by Steve Schmerler's avatar Steve Schmerler
Browse files

gp: 01_one_dim: plot ground truth

parent fd9943cb
Branches
No related tags found
1 merge request!2Update GP slides and notebooks
......@@ -77,26 +77,35 @@ torch.manual_seed(123)
# +
def ground_truth(x, const):
return torch.sin(x) * torch.exp(-0.2 * x) + const
def generate_data(x, gaps=[[1, 3]], const=5):
y = torch.sin(x) * torch.exp(-0.2 * x) + torch.randn(x.shape) * 0.1 + const
y = ground_truth(x, const=const) + torch.randn(x.shape) * 0.1
msk = torch.tensor([True] * len(x))
if gaps is not None:
for g in gaps:
msk = msk & ~((x > g[0]) & (x < g[1]))
return x[msk], y[msk]
return x[msk], y[msk], y
const = 5.0
x = torch.linspace(0, 4 * math.pi, 100)
X_train, y_train = generate_data(x, gaps=[[6, 10]])
X_train, y_train, y_gt_train = generate_data(x, gaps=[[6, 10]], const=const)
X_pred = torch.linspace(
X_train[0] - 2, X_train[-1] + 2, 200, requires_grad=False
)
y_gt_pred = ground_truth(X_pred, const=const)
print(f"{X_train.shape=}")
print(f"{y_train.shape=}")
print(f"{X_pred.shape=}")
plt.scatter(X_train, y_train, marker="o", color="tab:blue")
fig, ax = plt.subplots()
ax.scatter(X_train, y_train, marker="o", color="tab:blue", label="noisy data")
ax.plot(X_pred, y_gt_pred, ls="--", color="k", label="ground truth")
ax.legend()
# -
# # Define GP model
......@@ -264,6 +273,14 @@ with torch.no_grad():
color="tab:red",
lw=2,
)
ax.plot(
X_pred.numpy(),
y_gt_pred.numpy(),
label="ground truth",
color="k",
lw=2,
ls="--",
)
ax.fill_between(
X_pred.numpy(),
lower.numpy(),
......@@ -377,6 +394,14 @@ with torch.no_grad():
color="tab:red",
lw=2,
)
ax.plot(
X_pred.numpy(),
y_gt_pred.numpy(),
label="ground truth",
color="k",
lw=2,
ls="--",
)
ax.fill_between(
X_pred.numpy(),
lower.numpy(),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment