diff --git a/pytorch-fsdp-example/main.py b/pytorch-fsdp-example/main.py
index c72dd78bbb7569e48bf7e0c34f16ad89e226022e..6a24d934ec6bb312b225d51b868fe9506a5df605 100644
--- a/pytorch-fsdp-example/main.py
+++ b/pytorch-fsdp-example/main.py
@@ -67,6 +67,17 @@ def parse_args():
         default=0,
         help='Random number generator initialization value.',
     )
+    parser.add_argument(
+        '--num-fsdp-replicas',
+        type=int,
+        help=(
+            'How many FSDP replicas to use for hybrid sharded data '
+            'parallelism (HSDP). The model will be sharded into '
+            '`world_size / num_fsdp_replicas` partitions per replica. '
+            'Gradients will be all-reduced across the replicas. '
+            'If not given, use standard FSDP.'
+        ),
+    )
     args = parser.parse_args()
     return args
 
@@ -283,14 +294,25 @@ def main():
     train_dset, valid_dset, test_dset = prepare_datasets(args, device)
 
     model = build_model()
-    mesh_1d = device_mesh.init_device_mesh(
-        "cuda",
-        (torch.distributed.get_world_size(),),
-    )
+    # Set up FSDP or HSDP.
+    if args.num_fsdp_replicas is None:
+        fsdp_mesh_dims = (torch.distributed.get_world_size(),)
+        sharding_strategy = fsdp.ShardingStrategy.FULL_SHARD
+    else:
+        assert (
+            torch.distributed.get_world_size() % args.num_fsdp_replicas
+            == 0
+        ), 'world size must be divisible by number of FSDP replicas'
+        fsdp_shards_per_replica = \
+            torch.distributed.get_world_size() // args.num_fsdp_replicas
+        fsdp_mesh_dims = (args.num_fsdp_replicas, fsdp_shards_per_replica)
+        sharding_strategy = fsdp.ShardingStrategy.HYBRID_SHARD
+    fsdp_mesh = device_mesh.init_device_mesh("cuda", fsdp_mesh_dims)
     model = fsdp.FullyShardedDataParallel(
         model,
         device_id=local_rank,
-        device_mesh=mesh_1d,
+        device_mesh=fsdp_mesh,
+        sharding_strategy=sharding_strategy,
         auto_wrap_policy=functools.partial(
             fsdp.wrap.size_based_auto_wrap_policy,
             # Wrap every 1B parameters.
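
For context, the new flag maps the world size onto a two-dimensional device mesh: on, say, 16 GPUs, `--num-fsdp-replicas 4` yields a `(4, 4)` mesh, i.e. 4 replicas that each shard the model across 4 GPUs and all-reduce gradients with one another. Below is a minimal standalone sketch of the same HSDP setup using the public PyTorch APIs the patch relies on; the `wrap_hsdp` helper name and the use of the `LOCAL_RANK` environment variable are illustrative assumptions, not part of the patch.

```python
# Minimal HSDP sketch, assuming torch.distributed.init_process_group() has
# already been called (e.g. by a torchrun-launched training script) and that
# the world size is divisible by num_replicas.
import os

import torch
from torch import nn
from torch.distributed import device_mesh, fsdp


def wrap_hsdp(model: nn.Module, num_replicas: int) -> fsdp.FullyShardedDataParallel:
    world_size = torch.distributed.get_world_size()
    assert world_size % num_replicas == 0, \
        'world size must be divisible by number of FSDP replicas'
    # Dim 0: replicas whose gradients are all-reduced with one another;
    # dim 1: shards of the model parameters within each replica.
    mesh = device_mesh.init_device_mesh(
        "cuda",
        (num_replicas, world_size // num_replicas),
    )
    return fsdp.FullyShardedDataParallel(
        model,
        device_id=int(os.environ['LOCAL_RANK']),
        device_mesh=mesh,
        sharding_strategy=fsdp.ShardingStrategy.HYBRID_SHARD,
    )
```

With `ShardingStrategy.HYBRID_SHARD`, FSDP treats the first mesh dimension as the replication (gradient all-reduce) dimension and the second as the sharding dimension, which is why the patch builds the mesh as `(num_fsdp_replicas, fsdp_shards_per_replica)`.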