diff --git a/pytorch-fsdp-example/main.py b/pytorch-fsdp-example/main.py
index 60a24c84e1b310a2fb4a556cf55db74c0c135f35..6d6eeb3091c04907fac3f221918fbf72480536e9 100644
--- a/pytorch-fsdp-example/main.py
+++ b/pytorch-fsdp-example/main.py
@@ -231,11 +231,12 @@ def distribute_model(model, args):
     """Distribute the model across the different processes using Fully
     Sharded Data Parallelism (FSDP).
     """
-    local_rank = get_local_rank()
     # Set up FSDP or HSDP.
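+    # With FSDP2, the sharding strategy is implied by the mesh shape: a 1D
+    # mesh fully shards across all ranks (the former FULL_SHARD), while a
+    # 2D mesh replicates across dim 0 and shards across dim 1 (HYBRID_SHARD).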
     if args.num_fsdp_replicas is None:
         fsdp_mesh_dims = (torch.distributed.get_world_size(),)
-        sharding_strategy = fsdp.ShardingStrategy.FULL_SHARD
     else:
         assert (
             torch.distributed.get_world_size() % args.num_fsdp_replicas
@@ -244,7 +245,6 @@ def distribute_model(model, args):
         fsdp_shards_per_replica = \
             torch.distributed.get_world_size() // args.num_fsdp_replicas
         fsdp_mesh_dims = (args.num_fsdp_replicas, fsdp_shards_per_replica)
-        sharding_strategy = fsdp.ShardingStrategy.HYBRID_SHARD
     fsdp_mesh = device_mesh.init_device_mesh("cuda", fsdp_mesh_dims)
 
     if args.model_type == 'resnet':
@@ -263,13 +263,16 @@
     else:
         raise ValueError(f'unknown model type "{args.model_type}"')
 
-    model = fsdp.FullyShardedDataParallel(
-        model,
-        device_id=local_rank,
-        device_mesh=fsdp_mesh,
-        sharding_strategy=sharding_strategy,
-        auto_wrap_policy=auto_wrap_policy,
-    )
+    fsdp_kwargs = dict(mesh=fsdp_mesh)
+    # Iterate bottom-up so children are sharded before their parents, and
+    # skip the root module here, since it is sharded last.
+    for module in reversed(list(model.modules())):
+        if module is not model and auto_wrap_policy(module):
+            fsdp.fully_shard(module, **fsdp_kwargs)
+    fsdp.fully_shard(model, **fsdp_kwargs)
     return model
 
 
diff --git a/pytorch-fsdp-example/requirements.txt b/pytorch-fsdp-example/requirements.txt
index 5c68afaf98da12508f0489eb3b63a74e674ea436..823de046e4487996bcda3460f0f82a646423ab20 100644
--- a/pytorch-fsdp-example/requirements.txt
+++ b/pytorch-fsdp-example/requirements.txt
@@ -1,3 +1,5 @@
-torch>=2.2,<3
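+# FSDP2's `fully_shard` is available under the public `torch.distributed.fsdp`
+# namespace as of PyTorch 2.6.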
+torch>=2.6,<3
 torchrun_jsc>=0.0.15
 torchvision>=0.13