diff --git a/pytorch-fsdp-example/main.py b/pytorch-fsdp-example/main.py
index 44c7a2445a76a3b47e6a74fb48a385b582782c3e..d86328b8da04daf8691f79090c877afe1e709f63 100644
--- a/pytorch-fsdp-example/main.py
+++ b/pytorch-fsdp-example/main.py
@@ -5,7 +5,11 @@ import time
 
 import torch
 from torch.distributed import checkpoint as dcp
+# The commented modules require PyTorch ≥2.2 and enable more modern
+# APIs.
+# from torch.distributed import device_mesh
 from torch.distributed import fsdp
+# from torch.distributed.checkpoint import state_dict as dist_state_dict
 import torchvision
 
 
@@ -104,6 +108,16 @@ def save_model_singular(model, *args, **kwargs):
             save_policy,
     ):
         cpu_state = model.state_dict()
+    # For PyTorch versions ≥2.2:
+    # state_dict_options = dist_state_dict.StateDictOptions(
+    #     full_state_dict=True,
+    #     cpu_offload=True,
+    # )
+    # cpu_state = dist_state_dict.get_model_state_dict(
+    #     model,
+    #     options=state_dict_options,
+    # )
+
     # We do *not* want to write to the same location with multiple
     # processes at the same time.
     if is_root_process():
@@ -123,10 +137,25 @@ def save_model(model, save_dir):
             state_dict_config,
     ):
         cp_state_dict = {'model': model.state_dict()}
+
     dcp.save_state_dict(
         cp_state_dict,
         dcp.FileSystemWriter(save_dir),
     )
+    # For PyTorch versions ≥2.2:
+    # state_dict_options = dist_state_dict.StateDictOptions(
+    #     cpu_offload=False,
+    # )
+    # model_state_dict = dist_state_dict.get_model_state_dict(
+    #     model,
+    #     options=state_dict_options,
+    # )
+    # cp_state_dict = {'model': model_state_dict}
+    #
+    # dcp.save(
+    #     cp_state_dict,
+    #     storage_writer=dcp.FileSystemWriter(save_dir, overwrite=True),
+    # )
 
 
 def load_model_singular(model, *args, **kwargs):
@@ -149,11 +178,27 @@ def load_model(model, load_dir):
             state_dict_config,
     ):
         cp_state_dict = {'model': model.state_dict()}
+
     dcp.load_state_dict(
         cp_state_dict,
         dcp.FileSystemReader(load_dir),
     )
     model.load_state_dict(cp_state_dict['model'])
+    # For PyTorch versions ≥2.2:
+    # state_dict_options = dist_state_dict.StateDictOptions(
+    #     cpu_offload=False,
+    # )
+    # model_state_dict = dist_state_dict.get_model_state_dict(
+    #     model,
+    #     options=state_dict_options,
+    # )
+    # cp_state_dict = {'model': model_state_dict}
+    #
+    # dcp.load(
+    #     cp_state_dict,
+    #     storage_reader=dcp.FileSystemReader(load_dir),
+    # )
+    # dist_state_dict.set_model_state_dict(model, cp_state_dict['model'])
 
 
 def all_reduce_avg(tensor):
@@ -276,9 +321,16 @@ def main():
     train_dset, valid_dset, test_dset = prepare_datasets(args, device)
 
     model = build_model()
+    # For PyTorch versions ≥2.2, `mesh_1d` should be included to enable
+    # use of more modern implementations:
+    # mesh_1d = device_mesh.init_device_mesh(
+    #     "cuda",
+    #     (torch.distributed.get_world_size(),),
+    # )
     model = fsdp.FullyShardedDataParallel(
         model,
         device_id=local_rank,
+        # device_mesh=mesh_1d,
         auto_wrap_policy=functools.partial(
             fsdp.wrap.size_based_auto_wrap_policy,
             # Wrap every 1B parameters.