diff --git a/pytorch-fsdp-example/main.py b/pytorch-fsdp-example/main.py
index d86328b8da04daf8691f79090c877afe1e709f63..d6ea289896993bdf05b6aa8475766cb808915ad1 100644
--- a/pytorch-fsdp-example/main.py
+++ b/pytorch-fsdp-example/main.py
@@ -5,11 +5,9 @@ import time
 
 import torch
 from torch.distributed import checkpoint as dcp
-# The commented modules require PyTorch ≥2.2 and enable more modern
-# APIs.
-# from torch.distributed import device_mesh
+from torch.distributed import device_mesh
 from torch.distributed import fsdp
-# from torch.distributed.checkpoint import state_dict as dist_state_dict
+from torch.distributed.checkpoint import state_dict as dist_state_dict
 import torchvision
 
 
@@ -100,23 +98,16 @@ def save_model_singular(model, *args, **kwargs):
     other given arguments to `torch.save` to save the model, but only on
     the root process.
     """
-    save_policy = fsdp.FullStateDictConfig(
-        offload_to_cpu=True, rank0_only=True)
-    with fsdp.FullyShardedDataParallel.state_dict_type(
-            model,
-            fsdp.StateDictType.FULL_STATE_DICT,
-            save_policy,
-    ):
-        cpu_state = model.state_dict()
-    # For PyTorch versions ≥2.2:
-    # state_dict_options = dist_state_dict.StateDictOptions(
-    #     full_state_dict=True,
-    #     cpu_offload=True,
-    # )
-    # cpu_state = dist_state_dict.get_model_state_dict(
-    #     model,
-    #     options=state_dict_options,
-    # )
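+    # Gather the full, unsharded state dict, offloaded to CPU, so it
+    # can be written below with a plain `torch.save`.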
+    state_dict_options = dist_state_dict.StateDictOptions(
+        full_state_dict=True,
+        cpu_offload=True,
+    )
+    cpu_state = dist_state_dict.get_model_state_dict(
+        model,
+        options=state_dict_options,
+    )
 
     # We do *not* want to write to the same location with multiple
     # processes at the same time.
@@ -130,32 +121,21 @@ def save_model(model, save_dir):
     distributed checkpoint means that the checkpoint will be split into
     individual files, one for each process.
     """
-    state_dict_config = fsdp.ShardedStateDictConfig(offload_to_cpu=False)
-    with fsdp.FullyShardedDataParallel.state_dict_type(
-            model,
-            fsdp.StateDictType.SHARDED_STATE_DICT,
-            state_dict_config,
-    ):
-        cp_state_dict = {'model': model.state_dict()}
-
-    dcp.save_state_dict(
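+    # The default (non-full) state dict keeps each rank's local shards,
+    # so every process only serializes its own part of the model.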
+    state_dict_options = dist_state_dict.StateDictOptions(
+        cpu_offload=False,
+    )
+    model_state_dict = dist_state_dict.get_model_state_dict(
+        model,
+        options=state_dict_options,
+    )
+    cp_state_dict = {'model': model_state_dict}
+
+    dcp.save(
         cp_state_dict,
-        dcp.FileSystemWriter(save_dir),
+        storage_writer=dcp.FileSystemWriter(save_dir, overwrite=True),
     )
-    # For PyTorch versions ≥2.2:
-    # state_dict_options = dist_state_dict.StateDictOptions(
-    #     cpu_offload=False,
-    # )
-    # model_state_dict = dist_state_dict.get_model_state_dict(
-    #     model,
-    #     options=state_dict_options,
-    # )
-    # cp_state_dict = {'model': model_state_dict}
-    #
-    # dcp.save(
-    #     cp_state_dict,
-    #     storage_writer=dcp.FileSystemWriter(save_dir, overwrite=True),
-    # )
 
 
 def load_model_singular(model, *args, **kwargs):
@@ -171,34 +151,23 @@ def load_model(model, load_dir):
     """Set the given model's state dictionary in-place from the given
     distributed checkpoint directory.
     """
-    state_dict_config = fsdp.ShardedStateDictConfig(offload_to_cpu=False)
-    with fsdp.FullyShardedDataParallel.state_dict_type(
-            model,
-            fsdp.StateDictType.SHARDED_STATE_DICT,
-            state_dict_config,
-    ):
-        cp_state_dict = {'model': model.state_dict()}
-
-    dcp.load_state_dict(
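+    # Rebuild this rank's sharded state dict to use as the load target;
+    # `dcp.load` fills it in-place from the checkpoint files.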
+    state_dict_options = dist_state_dict.StateDictOptions(
+        cpu_offload=False,
+    )
+    model_state_dict = dist_state_dict.get_model_state_dict(
+        model,
+        options=state_dict_options,
+    )
+    cp_state_dict = {'model': model_state_dict}
+
+    dcp.load(
         cp_state_dict,
-        dcp.FileSystemReader(load_dir),
+        storage_reader=dcp.FileSystemReader(load_dir),
     )
-    model.load_state_dict(cp_state_dict['model'])
-    # For PyTorch versions ≥2.2:
-    # state_dict_options = dist_state_dict.StateDictOptions(
-    #     cpu_offload=False,
-    # )
-    # model_state_dict = dist_state_dict.get_model_state_dict(
-    #     model,
-    #     options=state_dict_options,
-    # )
-    # cp_state_dict = {'model': model_state_dict}
-    #
-    # dcp.load(
-    #     cp_state_dict,
-    #     storage_reader=dcp.FileSystemReader(load_dir),
-    # )
-    # dist_state_dict.set_model_state_dict(model, cp_state_dict['model'])
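+    # Copy the loaded values back into the wrapped model's parameters.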
+    dist_state_dict.set_model_state_dict(model, cp_state_dict['model'])
 
 
 def all_reduce_avg(tensor):
@@ -321,16 +290,16 @@ def main():
     train_dset, valid_dset, test_dset = prepare_datasets(args, device)
 
     model = build_model()
-    # For PyTorch versions ≥2.2, `mesh_1d` should be included to enable
-    # use of more modern implementations:
-    # mesh_1d = device_mesh.init_device_mesh(
-    #     "cuda",
-    #     (torch.distributed.get_world_size(),),
-    # )
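+    # Describe all ranks as a single 1-D device mesh for FSDP to shard
+    # the model across.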
+    mesh_1d = device_mesh.init_device_mesh(
+        "cuda",
+        (torch.distributed.get_world_size(),),
+    )
     model = fsdp.FullyShardedDataParallel(
         model,
         device_id=local_rank,
-        # device_mesh=mesh_1d,
+        device_mesh=mesh_1d,
         auto_wrap_policy=functools.partial(
             fsdp.wrap.size_based_auto_wrap_policy,
             # Wrap every 1B parameters.