diff --git a/BLcourse2.3/02_two_dim.py b/BLcourse2.3/02_two_dim.py
index b8c9f2c23cd3e18cadcdff10d6c4449f3d309cf0..0e7333177178ed8e9380b0c0e1567f6965e1c343 100644
--- a/BLcourse2.3/02_two_dim.py
+++ b/BLcourse2.3/02_two_dim.py
@@ -13,7 +13,9 @@
 #     name: python3
 # ---
 
-# %matplotlib notebook
+# +
+# %matplotlib widget
+# -
 
 # +
 from collections import defaultdict
@@ -23,21 +25,23 @@ import torch
 import gpytorch
 from matplotlib import pyplot as plt
 from matplotlib import is_interactive
-##from mpl_toolkits.mplot3d import Axes3D
+from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers 3D projection)
 
+from sklearn.preprocessing import StandardScaler
 
 from utils import extract_model_params, plot_samples, ExactGPModel
-# -
 
 
+# +
 torch.set_default_dtype(torch.float64)
 torch.manual_seed(123)
+# -
 
 # # Generate toy 2D data
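+#
+# The `MexicanHat` class evaluates a damped radial sine surface on either a
+# regular grid (mode="grid") or at random points in the domain (mode="rand").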
 
 
 # +
-class SurfaceData:
+class MexicanHat:
     def __init__(self, xlim, ylim, nx, ny, mode, **kwds):
         self.xlim = xlim
         self.ylim = ylim
@@ -46,7 +50,7 @@ class SurfaceData:
         self.xg, self.yg = self.get_xy_grid()
         self.XG, self.YG = torch.meshgrid(self.xg, self.yg, indexing="ij")
         self.X = self.make_X(mode)
-        self.z = self.generate(self.X, **kwds)
+        self.z = self.func(self.X)
 
     def make_X(self, mode="grid"):
         if mode == "grid":
@@ -64,46 +68,38 @@ class SurfaceData:
         y = torch.linspace(self.ylim[0], self.ylim[1], self.ny)
         return x, y
 
-    def func(self, X):
-        raise NotImplementedError
-
-    def generate(self, *args, **kwds):
-        if "der" in kwds:
-            der = kwds["der"]
-            kwds.pop("der")
-            if der == "x":
-                return self.deriv_x(*args, **kwds)
-            elif der == "y":
-                return self.deriv_y(*args, **kwds)
-            else:
-                raise Exception("der != 'x' or 'y'")
-        else:
-            return self.func(*args, **kwds)
-
-
-class MexicanHat(SurfaceData):
-    def func(self, X):
+    @staticmethod
+    def func(X):
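+        # rotationally symmetric, damped sine wave: f(x) = 10 sin(|x|) / |x|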
         r = torch.sqrt((X**2).sum(axis=1))
-        return torch.sin(r) / r
+        return torch.sin(r) / r * 10
 
-    def deriv_x(self, X):
-        r = torch.sqrt((X**2).sum(axis=1))
-        x = X[:, 0]
-        return x * torch.cos(r) / r**2 - x * torch.sin(r) / r**3.0
+    @staticmethod
+    def n2t(x):
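+        # numpy -> torch; sklearn transforms return numpy arrays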
+        return torch.from_numpy(x)
 
-    def deriv_y(self, X):
-        r = torch.sqrt((X**2).sum(axis=1))
-        y = X[:, 1]
-        return y * torch.cos(r) / r**2 - y * torch.sin(r) / r**3.0
+    def apply_scalers(self, x_scaler, y_scaler):
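+        # Transform inputs, grid axes and targets in-place with scalers
+        # fitted on the train data. Stacking xg and yg as columns assumes
+        # nx == ny, which holds for both datasets here.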
+        self.X = self.n2t(x_scaler.transform(self.X))
+        Xg = x_scaler.transform(torch.stack((self.xg, self.yg), dim=1))
+        self.xg = self.n2t(Xg[:, 0])
+        self.yg = self.n2t(Xg[:, 1])
+        self.XG, self.YG = torch.meshgrid(self.xg, self.yg, indexing="ij")
+        self.z = self.n2t(y_scaler.transform(self.z[:, None])[:, 0])
 
 
+# -
+
 # +
 data_train = MexicanHat(
     xlim=[-15, 5], ylim=[-15, 25], nx=20, ny=20, mode="rand"
 )
+x_scaler = StandardScaler().fit(data_train.X)
+y_scaler = StandardScaler().fit(data_train.z[:, None])
+data_train.apply_scalers(x_scaler, y_scaler)
+
 data_pred = MexicanHat(
     xlim=[-15, 5], ylim=[-15, 25], nx=100, ny=100, mode="grid"
 )
+data_pred.apply_scalers(x_scaler, y_scaler)
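+
+# Sanity check (added sketch, not part of the original lesson): after
+# StandardScaler, the scaled train targets have ~zero mean and ~unit
+# variance, and y_scaler.inverse_transform maps them back to original
+# units (returns numpy; the name z_unscaled is just for illustration).
+assert abs(data_train.z.mean().item()) < 1e-10
+assert abs((data_train.z**2).mean().item() - 1.0) < 1e-10
+z_unscaled = y_scaler.inverse_transform(data_train.z[:, None])[:, 0]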
 
 # train inputs
 X_train = data_train.X
@@ -112,10 +108,11 @@ X_train = data_train.X
 X_pred = data_pred.X
 
 # noise-free train data
-##y_train = data_train.z
+y_train = data_train.z
 
 # noisy train data
-y_train = data_train.z + torch.randn(size=data_train.z.shape) / 20
+##y_train = data_train.z + torch.randn(size=data_train.z.shape) / 5
+# -
 
 # +
 # Cut out part of the train data to create out-of-distribution predictions
@@ -154,21 +151,24 @@ pprint(extract_model_params(model, raw=False))
 # +
 # Set new start hyper params
 model.mean_module.constant = 0.0
-model.covar_module.base_kernel.lengthscale = 1.0
-model.covar_module.outputscale = 1.0
+model.covar_module.base_kernel.lengthscale = 3.0
+model.covar_module.outputscale = 8.0
 model.likelihood.noise_covar.noise = 0.1
 
 pprint(extract_model_params(model, raw=False))
+# -
+
+# # Fit GP to data: optimize hyper params
 
 # +
 # Train mode
 model.train()
 likelihood.train()
 
-optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
+optimizer = torch.optim.Adam(model.parameters(), lr=0.2)
 loss_func = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
 
-n_iter = 200
+n_iter = 300
 history = defaultdict(list)
 for ii in range(n_iter):
     optimizer.zero_grad()
@@ -182,15 +182,19 @@ for ii in range(n_iter):
     history["loss"].append(loss.item())
 # -
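+
+# Convergence sanity check (added sketch, not in the original lesson): the
+# loss should end below its starting value; if not, lower the lr or train
+# longer.
+assert history["loss"][-1] < history["loss"][0]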
 
+# +
 ncols = len(history)
 fig, axs = plt.subplots(ncols=ncols, nrows=1, figsize=(ncols * 5, 5))
 for ax, (p_name, p_lst) in zip(axs, history.items()):
     ax.plot(p_lst)
     ax.set_title(p_name)
     ax.set_xlabel("iterations")
+# -
 
+# +
 # Values of optimized hyper params
 pprint(extract_model_params(model, raw=False))
+# -
 
 # # Run prediction
 
@@ -221,11 +225,17 @@ with torch.no_grad():
     ax.set_xlabel("X_0")
     ax.set_ylabel("X_1")
 
+# The posteriors over f and y share the same mean; the likelihood noise
+# only enlarges the predictive variance.
+assert (post_pred_f.mean == post_pred_y.mean).all()
+# -
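+
+# Variance decomposition (added sketch, not in the original lesson): with a
+# Gaussian likelihood, y's predictive variance is f's variance plus the
+# learned noise variance, i.e. epistemic plus aleatoric.
+with torch.no_grad():
+    noise = model.likelihood.noise_covar.noise
+    assert torch.allclose(post_pred_y.variance, post_pred_f.variance + noise)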
+
+
 # # Plot difference to ground truth and uncertainty
 
 # +
-ncols = 2
-fig, axs = plt.subplots(ncols=ncols, nrows=1, figsize=(ncols * 5, 5))
+ncols = 3
+fig, axs = plt.subplots(ncols=ncols, nrows=1, figsize=(ncols * 7, 5))
+
+# Shared color scale so the two stddev panels are directly comparable
+vmax = post_pred_y.stddev.max()
 
 c0 = axs[0].contourf(
     data_pred.XG,
@@ -235,19 +245,31 @@ c0 = axs[0].contourf(
     ),
 )
 axs[0].set_title("|y_pred - y_true|")
+
 c1 = axs[1].contourf(
+    data_pred.XG,
+    data_pred.YG,
+    post_pred_f.stddev.reshape((data_pred.nx, data_pred.ny)),
+    vmax=vmax,
+)
+axs[1].set_title("f_std (epistemic)")
+
+c2 = axs[2].contourf(
     data_pred.XG,
     data_pred.YG,
     post_pred_y.stddev.reshape((data_pred.nx, data_pred.ny)),
+    vmax=vmax,
 )
-axs[1].set_title("y_std")
+axs[2].set_title("y_std (epistemic + aleatoric)")
 
-for ax, c in zip(axs, [c0, c1]):
+for ax, c in zip(axs, [c0, c1, c2]):
     ax.set_xlabel("X_0")
     ax.set_ylabel("X_1")
     ax.scatter(x=X_train[:, 0], y=X_train[:, 1], color="white", alpha=0.2)
     fig.colorbar(c, ax=ax)
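+
+# Note (added): both stddev maps should be largest in the cut-out region and
+# near the domain edges, where train points are sparse.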
 
+
 # When running as script
 if not is_interactive():
     plt.show()
+# -