From 1707daa17a5bd5024064dcf5dd1c3c95fcddcfab Mon Sep 17 00:00:00 2001
From: janEbert <janpublicebert@posteo.net>
Date: Tue, 9 Jul 2024 12:41:27 +0200
Subject: [PATCH] Use a simple auto wrapping policy

This way, the example actually shows a way to properly distribute a
module using FSDP.
---
 pytorch-fsdp-example/main.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/pytorch-fsdp-example/main.py b/pytorch-fsdp-example/main.py
index a5cc9a6..f8ebfbf 100644
--- a/pytorch-fsdp-example/main.py
+++ b/pytorch-fsdp-example/main.py
@@ -264,6 +264,11 @@ def main():
     model = fsdp.FullyShardedDataParallel(
         model,
         device_id=local_rank,
+        auto_wrap_policy=functools.partial(
+            fsdp.wrap.size_based_auto_wrap_policy,
+            # Wrap every 1B parameters.
+            min_num_params=int(1e9),
+        ),
     )
 
     loss_func = torch.nn.CrossEntropyLoss()
-- 
GitLab
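
For reference, below is a minimal, self-contained sketch of the FSDP construction
this patch produces. It assumes main.py imports functools and
torch.distributed.fsdp (as fsdp) and has already initialized the default
process group with one process per GPU; wrap_model is a hypothetical helper
used here only for illustration, not a function from main.py.

import functools

import torch
from torch.distributed import fsdp
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy


def wrap_model(model: torch.nn.Module, local_rank: int) -> fsdp.FullyShardedDataParallel:
    # Shard the model with FSDP, letting the size-based policy decide where
    # to start new FSDP units instead of wrapping the whole model as one unit.
    return fsdp.FullyShardedDataParallel(
        model,
        device_id=local_rank,
        auto_wrap_policy=functools.partial(
            size_based_auto_wrap_policy,
            # Start a new FSDP unit roughly every 1B parameters.
            min_num_params=int(1e9),
        ),
    )

Without an auto_wrap_policy, FullyShardedDataParallel treats the entire model
as a single flat unit, so parameters are not actually sharded into smaller
groups; the size-based policy is the simplest way to get real sharding
without annotating individual submodules.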