Commit 1707daa1 authored by Jan Ebert

Use a simple auto wrapping policy

This way, the example actually shows a way to properly distribute a
module using FSDP.
parent c040e4d9
@@ -264,6 +264,11 @@ def main():
     model = fsdp.FullyShardedDataParallel(
         model,
         device_id=local_rank,
+        auto_wrap_policy=functools.partial(
+            fsdp.wrap.size_based_auto_wrap_policy,
+            # Wrap every 1B parameters.
+            min_num_params=int(1e9),
+        ),
     )
     loss_func = torch.nn.CrossEntropyLoss()
...
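For context, below is a minimal, self-contained sketch of how the example could look with this change applied. Only the auto_wrap_policy argument mirrors the diff above; the toy Sequential model, the NCCL backend, and the assumption that the script is launched with torchrun (which sets LOCAL_RANK) are illustrative stand-ins, not part of the commit.

import functools
import os

import torch
import torch.distributed as dist
import torch.distributed.fsdp as fsdp
import torch.distributed.fsdp.wrap  # makes fsdp.wrap accessible, as used in the diff


def main():
    # Assumes launch via torchrun, which sets LOCAL_RANK for each process.
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend='nccl')

    # Toy model standing in for the example's real network.
    model = torch.nn.Sequential(
        torch.nn.Linear(1024, 4096),
        torch.nn.ReLU(),
        torch.nn.Linear(4096, 1024),
    )

    # Submodules are split into their own FSDP units once their
    # unwrapped parameter count reaches `min_num_params`; smaller
    # modules remain part of the enclosing unit.
    model = fsdp.FullyShardedDataParallel(
        model,
        device_id=local_rank,
        auto_wrap_policy=functools.partial(
            fsdp.wrap.size_based_auto_wrap_policy,
            # Wrap every 1B parameters.
            min_num_params=int(1e9),
        ),
    )

    loss_func = torch.nn.CrossEntropyLoss()
    # ... training loop continues as in the original example.


if __name__ == '__main__':
    main()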