From 1707daa17a5bd5024064dcf5dd1c3c95fcddcfab Mon Sep 17 00:00:00 2001
From: janEbert <janpublicebert@posteo.net>
Date: Tue, 9 Jul 2024 12:41:27 +0200
Subject: [PATCH] Use a simple auto wrapping policy

This way, the example actually demonstrates how to properly distribute a
module across workers using FSDP.
---
 pytorch-fsdp-example/main.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/pytorch-fsdp-example/main.py b/pytorch-fsdp-example/main.py
index a5cc9a6..f8ebfbf 100644
--- a/pytorch-fsdp-example/main.py
+++ b/pytorch-fsdp-example/main.py
@@ -264,6 +264,11 @@ def main():
     model = fsdp.FullyShardedDataParallel(
         model,
         device_id=local_rank,
+        auto_wrap_policy=functools.partial(
+            fsdp.wrap.size_based_auto_wrap_policy,
+            # Wrap any submodule whose parameters total at least 1B.
+            min_num_params=int(1e9),
+        ),
     )
     loss_func = torch.nn.CrossEntropyLoss()
 
-- 
GitLab
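
For reference, after applying the patch the relevant part of main.py reads
roughly as in the sketch below. This is a hedged reconstruction, not the full
example: the `fsdp` alias for `torch.distributed.fsdp`, the `local_rank`
variable, the stand-in model, and the process-group setup are assumptions
about the surrounding code, and the hunk itself relies on `functools` already
being imported in main.py.

    import functools
    import os

    import torch
    import torch.distributed as dist
    import torch.distributed.fsdp as fsdp


    def main():
        # Assumes a torchrun launch, which provides LOCAL_RANK in the environment.
        dist.init_process_group(backend="nccl")
        local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(local_rank)

        # Stand-in model; the real example builds its own network.
        model = torch.nn.Sequential(
            torch.nn.Linear(1024, 4096),
            torch.nn.ReLU(),
            torch.nn.Linear(4096, 10),
        ).to(local_rank)

        model = fsdp.FullyShardedDataParallel(
            model,
            device_id=local_rank,
            # Without an auto_wrap_policy, FSDP flattens the whole model into a
            # single unit; the size-based policy instead wraps (and shards
            # separately) submodules once they reach min_num_params parameters.
            auto_wrap_policy=functools.partial(
                fsdp.wrap.size_based_auto_wrap_policy,
                min_num_params=int(1e9),
            ),
        )
        loss_func = torch.nn.CrossEntropyLoss()
        # ... training loop as in the original example.


    if __name__ == "__main__":
        main()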