diff --git a/pytorch-fsdp-example/main.py b/pytorch-fsdp-example/main.py
index a5cc9a62adfaf996378101c4118ee57ab039e10a..f8ebfbfcb718154686749898fca193e13de32f6b 100644
--- a/pytorch-fsdp-example/main.py
+++ b/pytorch-fsdp-example/main.py
@@ -264,6 +264,11 @@ def main():
     model = fsdp.FullyShardedDataParallel(
         model,
         device_id=local_rank,
+        auto_wrap_policy=functools.partial(
+            fsdp.wrap.size_based_auto_wrap_policy,
+            # Start a new FSDP unit roughly every 1B parameters.
+            min_num_params=int(1e9),
+        ),
     )
     loss_func = torch.nn.CrossEntropyLoss()
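
For context, here is a minimal, self-contained sketch of how `size_based_auto_wrap_policy` is typically wired into `FullyShardedDataParallel`. It is illustrative only: the toy model, the `1e6` threshold, and the single-process group setup are placeholders, not the example repo's actual code, and it assumes `functools` is imported in `main.py` (the diff above does not show that import being added).

```python
# Minimal sketch: size-based auto-wrapping with FSDP in a single process.
# The model, threshold, and process-group setup are illustrative placeholders.
import functools
import os

import torch
import torch.distributed as dist
from torch.distributed import fsdp


def main():
    # Single-process process group so FSDP can be constructed standalone.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    use_cuda = torch.cuda.is_available()
    dist.init_process_group("nccl" if use_cuda else "gloo", rank=0, world_size=1)

    model = torch.nn.Sequential(
        torch.nn.Linear(1024, 4096),
        torch.nn.ReLU(),
        torch.nn.Linear(4096, 1024),
    )

    # size_based_auto_wrap_policy starts a new FSDP unit whenever the
    # accumulated parameter count of not-yet-wrapped submodules exceeds
    # min_num_params; 1e6 is a deliberately small threshold for this toy model.
    wrap_policy = functools.partial(
        fsdp.wrap.size_based_auto_wrap_policy,
        min_num_params=int(1e6),
    )

    sharded = fsdp.FullyShardedDataParallel(
        model,
        device_id=torch.cuda.current_device() if use_cuda else None,
        auto_wrap_policy=wrap_policy,
    )
    # Printing the wrapped model shows which submodules became FSDP units.
    print(sharded)

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

Without an `auto_wrap_policy`, FSDP treats the whole model as a single flat unit, which limits how much parameter, gradient, and optimizer-state memory can be freed between layers; the size-based policy splits the model into multiple units so sharding and prefetching happen per unit.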