diff --git a/pytorch-fsdp-example/main.py b/pytorch-fsdp-example/main.py
index 5843373362c3d5486f84836d7d4470a54f4a80d0..991e6fc95d38be1c4ce8c841e20529a16ff56c66 100644
--- a/pytorch-fsdp-example/main.py
+++ b/pytorch-fsdp-example/main.py
@@ -4,7 +4,7 @@ import os
 import time
 
 import torch
-from torch.distributed import checkpoint as dist_checkpoint
+from torch.distributed import checkpoint as dcp
 from torch.distributed import fsdp
 import torchvision
 
@@ -123,9 +123,9 @@ def save_model(model, save_dir):
             state_dict_config,
     ):
         cp_state_dict = {'model': model.state_dict()}
-    dist_checkpoint.save_state_dict(
+    dcp.save_state_dict(
         cp_state_dict,
-        dist_checkpoint.FileSystemWriter(save_dir),
+        dcp.FileSystemWriter(save_dir),
     )
 
 
@@ -140,9 +140,9 @@ def load_model(model, load_dir):
             state_dict_config,
     ):
         cp_state_dict = {'model': model.state_dict()}
-    dist_checkpoint.load_state_dict(
+    dcp.load_state_dict(
         cp_state_dict,
-        dist_checkpoint.FileSystemReader(load_dir),
+        dcp.FileSystemReader(load_dir),
     )
     model.load_state_dict(cp_state_dict['model'])