From 53c4b256ab8bfd98f18c37d64ced69f1e78d5082 Mon Sep 17 00:00:00 2001
From: Steve Schmerler <git@elcorto.com>
Date: Sun, 18 May 2025 10:26:33 +0200
Subject: [PATCH] Sync extract_model_params() usage

---
 BLcourse2.3/01_one_dim.py     |  4 +++-
 BLcourse2.3/02_two_dim.py     | 13 +++++++------
 BLcourse2.3/03_one_dim_SVI.py | 10 +++++-----
 BLcourse2.3/utils.py          |  1 +
 4 files changed, 16 insertions(+), 12 deletions(-)

diff --git a/BLcourse2.3/01_one_dim.py b/BLcourse2.3/01_one_dim.py
index ec72e4e..f29e5bb 100644
--- a/BLcourse2.3/01_one_dim.py
+++ b/BLcourse2.3/01_one_dim.py
@@ -500,7 +500,9 @@ print(
 print(
     "learned noise:",
     np.sqrt(
-        extract_model_params(model, try_item=True)["likelihood.noise_covar.noise"]
+        extract_model_params(model, try_item=True)[
+            "likelihood.noise_covar.noise"
+        ]
     ),
 )
 # -
diff --git a/BLcourse2.3/02_two_dim.py b/BLcourse2.3/02_two_dim.py
index 35311de..c5ab47b 100644
--- a/BLcourse2.3/02_two_dim.py
+++ b/BLcourse2.3/02_two_dim.py
@@ -28,7 +28,6 @@
 # $\newcommand{\testtest}[1]{#1_{**}}$
 # $\DeclareMathOperator{\diag}{diag}$
 # $\DeclareMathOperator{\cov}{cov}$
-#
 
 # +
 # ##%matplotlib notebook
@@ -239,7 +238,7 @@ model = ExactGPModel(X_train, y_train, likelihood)
 print(model)
 
 # Default start hyper params
-pprint(extract_model_params(model, raw=False))
+pprint(extract_model_params(model))
 
 # +
 # Set new start hyper params
@@ -248,7 +247,7 @@ model.covar_module.base_kernel.lengthscale = 3.0
 model.covar_module.outputscale = 8.0
 model.likelihood.noise_covar.noise = 0.1
 
-pprint(extract_model_params(model, raw=False))
+pprint(extract_model_params(model))
 # -
 
 # # Fit GP to data: optimize hyper params
@@ -270,7 +269,7 @@ for ii in range(n_iter):
     optimizer.step()
     if (ii + 1) % 10 == 0:
         print(f"iter {ii + 1}/{n_iter}, {loss=:.3f}")
-    for p_name, p_val in extract_model_params(model).items():
+    for p_name, p_val in extract_model_params(model, try_item=True).items():
         history[p_name].append(p_val)
     history["loss"].append(loss.item())
 # -
@@ -283,7 +282,7 @@ for ax, (p_name, p_lst) in zip(axs, history.items()):
     ax.set_xlabel("iterations")
 
 # Values of optimized hyper params
-pprint(extract_model_params(model, raw=False))
+pprint(extract_model_params(model))
 
 # # Run prediction
 
@@ -403,7 +402,9 @@ print(
 print(
     "learned noise:",
     np.sqrt(
-        extract_model_params(model, raw=False)["likelihood.noise_covar.noise"]
+        extract_model_params(model, try_item=True)[
+            "likelihood.noise_covar.noise"
+        ]
     ),
 )
 # -
diff --git a/BLcourse2.3/03_one_dim_SVI.py b/BLcourse2.3/03_one_dim_SVI.py
index 86d7073..8efb75d 100644
--- a/BLcourse2.3/03_one_dim_SVI.py
+++ b/BLcourse2.3/03_one_dim_SVI.py
@@ -171,9 +171,9 @@ print(likelihood)
 
 # Default start hyper params
 print("model params:")
-pprint(extract_model_params(model, raw=False, try_item=False))
+pprint(extract_model_params(model))
 print("likelihood params:")
-pprint(extract_model_params(likelihood, raw=False, try_item=False))
+pprint(extract_model_params(likelihood))
 
 # +
 # Set new start hyper params
@@ -258,9 +258,9 @@ with torch.no_grad():
 
 # Values of optimized hyper params
 print("model params:")
-pprint(extract_model_params(model, raw=False, try_item=False))
+pprint(extract_model_params(model))
 print("likelihood params:")
-pprint(extract_model_params(likelihood, raw=False, try_item=False))
+pprint(extract_model_params(likelihood))
 
 
 # # Run prediction
@@ -364,7 +364,7 @@ print(
 print(
     "learned noise:",
     np.sqrt(
-        extract_model_params(likelihood, raw=False)["noise_covar.noise"]
+        extract_model_params(likelihood, try_item=True)["noise_covar.noise"]
     ),
 )
 # -
diff --git a/BLcourse2.3/utils.py b/BLcourse2.3/utils.py
index 0c0e605..260be47 100644
--- a/BLcourse2.3/utils.py
+++ b/BLcourse2.3/utils.py
@@ -13,6 +13,7 @@ def extract_model_params(model, raw=False, try_item=False) -> dict:
     See https://docs.gpytorch.ai/en/stable/examples/00_Basic_Usage/Hyperparameters.html#Raw-vs-Actual-Parameters
     """
     if try_item:
+
         def get_val(p):
             if isinstance(p, torch.Tensor):
                 if p.ndim == 0:
-- 
GitLab