Commit 8d16f2f1 authored by Steve Schmerler

Draft data scaling in 02_two_dim

parent a163fc69
1 merge request: BLcourse2.3!1 "Add GP part"
@@ -13,7 +13,9 @@
 # name: python3
 # ---

-# %matplotlib notebook
+# +
+# %matplotlib widget
+# -

 # +
 from collections import defaultdict
@@ -23,21 +25,23 @@ import torch
 import gpytorch
 from matplotlib import pyplot as plt
 from matplotlib import is_interactive
-##from mpl_toolkits.mplot3d import Axes3D
+from mpl_toolkits.mplot3d import Axes3D
+from sklearn.preprocessing import StandardScaler

 from utils import extract_model_params, plot_samples, ExactGPModel
+# -

+# +
 torch.set_default_dtype(torch.float64)
 torch.manual_seed(123)
+# -

 # # Generate toy 2D data

 # +
-class SurfaceData:
+class MexicanHat:
     def __init__(self, xlim, ylim, nx, ny, mode, **kwds):
         self.xlim = xlim
         self.ylim = ylim
@@ -46,7 +50,7 @@ class SurfaceData:
         self.xg, self.yg = self.get_xy_grid()
         self.XG, self.YG = torch.meshgrid(self.xg, self.yg, indexing="ij")
         self.X = self.make_X(mode)
-        self.z = self.generate(self.X, **kwds)
+        self.z = self.func(self.X)

     def make_X(self, mode="grid"):
         if mode == "grid":
@@ -64,46 +68,38 @@ class SurfaceData:
         y = torch.linspace(self.ylim[0], self.ylim[1], self.ny)
         return x, y

-    def func(self, X):
-        raise NotImplementedError
-
-    def generate(self, *args, **kwds):
-        if "der" in kwds:
-            der = kwds["der"]
-            kwds.pop("der")
-            if der == "x":
-                return self.deriv_x(*args, **kwds)
-            elif der == "y":
-                return self.deriv_y(*args, **kwds)
-            else:
-                raise Exception("der != 'x' or 'y'")
-        else:
-            return self.func(*args, **kwds)
-
-
-class MexicanHat(SurfaceData):
-    def func(self, X):
+    @staticmethod
+    def func(X):
         r = torch.sqrt((X**2).sum(axis=1))
-        return torch.sin(r) / r
+        return torch.sin(r) / r * 10

-    def deriv_x(self, X):
-        r = torch.sqrt((X**2).sum(axis=1))
-        x = X[:, 0]
-        return x * torch.cos(r) / r**2 - x * torch.sin(r) / r**3.0
-
-    def deriv_y(self, X):
-        r = torch.sqrt((X**2).sum(axis=1))
-        y = X[:, 1]
-        return y * torch.cos(r) / r**2 - y * torch.sin(r) / r**3.0
+    @staticmethod
+    def n2t(x):
+        return torch.from_numpy(x)
+
+    def apply_scalers(self, x_scaler, y_scaler):
+        self.X = self.n2t(x_scaler.transform(self.X))
+        Xg = x_scaler.transform(torch.stack((self.xg, self.yg), dim=1))
+        self.xg = self.n2t(Xg[:, 0])
+        self.yg = self.n2t(Xg[:, 1])
+        self.XG, self.YG = torch.meshgrid(self.xg, self.yg, indexing="ij")
+        self.z = self.n2t(y_scaler.transform(self.z[:, None])[:, 0])
+# -


 # +
 data_train = MexicanHat(
     xlim=[-15, 5], ylim=[-15, 25], nx=20, ny=20, mode="rand"
 )
+x_scaler = StandardScaler().fit(data_train.X)
+y_scaler = StandardScaler().fit(data_train.z[:, None])
+data_train.apply_scalers(x_scaler, y_scaler)

 data_pred = MexicanHat(
     xlim=[-15, 5], ylim=[-15, 25], nx=100, ny=100, mode="grid"
 )
+data_pred.apply_scalers(x_scaler, y_scaler)

 # train inputs
 X_train = data_train.X
@@ -112,10 +108,11 @@ X_train = data_train.X
 X_pred = data_pred.X

 # noise-free train data
-##y_train = data_train.z
+y_train = data_train.z

 # noisy train data
-y_train = data_train.z + torch.randn(size=data_train.z.shape) / 20
+##y_train = data_train.z + torch.randn(size=data_train.z.shape) / 5
+# -

 # +
 # Cut out part of the train data to create out-of-distribution predictions
@@ -154,21 +151,24 @@ pprint(extract_model_params(model, raw=False))

 # +
 # Set new start hyper params
 model.mean_module.constant = 0.0
-model.covar_module.base_kernel.lengthscale = 1.0
-model.covar_module.outputscale = 1.0
+model.covar_module.base_kernel.lengthscale = 3.0
+model.covar_module.outputscale = 8.0
 model.likelihood.noise_covar.noise = 0.1
 pprint(extract_model_params(model, raw=False))
+# -
+
+# # Fit GP to data: optimize hyper params

 # +
 # Train mode
 model.train()
 likelihood.train()

-optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
+optimizer = torch.optim.Adam(model.parameters(), lr=0.2)
 loss_func = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

-n_iter = 200
+n_iter = 300
 history = defaultdict(list)
 for ii in range(n_iter):
     optimizer.zero_grad()
@@ -182,15 +182,19 @@ for ii in range(n_iter):
     history["loss"].append(loss.item())
 # -

+# +
 ncols = len(history)
 fig, axs = plt.subplots(ncols=ncols, nrows=1, figsize=(ncols * 5, 5))
 for ax, (p_name, p_lst) in zip(axs, history.items()):
     ax.plot(p_lst)
     ax.set_title(p_name)
     ax.set_xlabel("iterations")
+# -

+# +
 # Values of optimized hyper params
 pprint(extract_model_params(model, raw=False))
+# -

 # # Run prediction
@@ -221,11 +225,17 @@ with torch.no_grad():
     ax.set_xlabel("X_0")
     ax.set_ylabel("X_1")

+assert (post_pred_f.mean == post_pred_y.mean).all()
+# -

 # # Plot difference to ground truth and uncertainty

 # +
-ncols = 2
-fig, axs = plt.subplots(ncols=ncols, nrows=1, figsize=(ncols * 5, 5))
+ncols = 3
+fig, axs = plt.subplots(ncols=ncols, nrows=1, figsize=(ncols * 7, 5))
+
+vmax = post_pred_y.stddev.max()

 c0 = axs[0].contourf(
     data_pred.XG,
@@ -235,19 +245,31 @@ c0 = axs[0].contourf(
     ),
 )
 axs[0].set_title("|y_pred - y_true|")

 c1 = axs[1].contourf(
+    data_pred.XG,
+    data_pred.YG,
+    post_pred_f.stddev.reshape((data_pred.nx, data_pred.ny)),
+    vmax=vmax,
+)
+axs[1].set_title("f_std (epistemic)")
+
+c2 = axs[2].contourf(
     data_pred.XG,
     data_pred.YG,
     post_pred_y.stddev.reshape((data_pred.nx, data_pred.ny)),
+    vmax=vmax,
 )
-axs[1].set_title("y_std")
+axs[2].set_title("y_std (epistemic + aleatoric)")

-for ax, c in zip(axs, [c0, c1]):
+for ax, c in zip(axs, [c0, c1, c2]):
     ax.set_xlabel("X_0")
     ax.set_ylabel("X_1")
     ax.scatter(x=X_train[:, 0], y=X_train[:, 1], color="white", alpha=0.2)
     fig.colorbar(c, ax=ax)

 # When running as script
 if not is_interactive():
     plt.show()
+# -
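
For reference, the scaling pattern this commit drafts, shown in isolation: fit StandardScalers on the train split only, then apply the same scalers to both the train data and the prediction grid, so no statistics of the prediction data leak into training. A minimal self-contained sketch with stand-in random arrays (not the notebook's MexicanHat objects):

import numpy as np
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(123)
X_train = rng.uniform(-15.0, 25.0, size=(400, 2))   # stand-in for data_train.X
z_train = rng.normal(size=(400,))                   # stand-in for data_train.z
X_pred = rng.uniform(-15.0, 25.0, size=(10000, 2))  # stand-in for data_pred.X

# Fit on train data only; StandardScaler expects 2D input, hence z_train[:, None]
x_scaler = StandardScaler().fit(X_train)
y_scaler = StandardScaler().fit(z_train[:, None])

# Apply the same fitted transform to both splits
Xs_train = x_scaler.transform(X_train)
Xs_pred = x_scaler.transform(X_pred)
zs_train = y_scaler.transform(z_train[:, None])[:, 0]

# Train features are now standardized: zero mean, unit variance per column
assert np.allclose(Xs_train.mean(axis=0), 0.0)
assert np.allclose(Xs_train.std(axis=0), 1.0)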
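With inputs and targets standardized, the GP trains and predicts entirely in scaled space, and this draft does not yet undo the scaling. A sketch of how predictions could be mapped back to original units, assuming the y_scaler above and torch-tensor mean/stddev as produced by the notebook's post_pred_y; the helper name unscale_predictions is ours. inverse_transform undoes both shift and scale for the mean, while the stddev only needs multiplication by y_scaler.scale_, since a standard deviation is invariant under the shift:

def unscale_predictions(y_scaler, mean_scaled, stddev_scaled):
    # Mean: undo centering and scaling, z = z_scaled * scale_ + mean_
    mean = y_scaler.inverse_transform(mean_scaled.numpy()[:, None])[:, 0]
    # Stddev: shift-invariant, so rescale only
    stddev = stddev_scaled.numpy() * y_scaler.scale_[0]
    return mean, stddev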
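The new third panel separates the two uncertainty sources that the commit now plots side by side. Under a Gaussian likelihood the decomposition is variance-additive: the predictive y-variance is the latent f-variance plus the learned noise variance (the means coincide, which the added assert in the diff checks). A sketch of verifying this, assuming post_pred_f, post_pred_y, and model from the notebook:

import torch

with torch.no_grad():
    var_f = post_pred_f.variance  # epistemic only
    var_y = post_pred_y.variance  # epistemic + aleatoric
    noise = model.likelihood.noise_covar.noise
    # y-variance should equal f-variance plus the likelihood's noise variance
    assert torch.allclose(var_y, var_f + noise)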