Commit 1707daa1 authored by Jan Ebert

Use a simple auto wrapping policy

This way, the example actually shows a way to properly distribute a
module using FSDP.
parent c040e4d9
@@ -264,6 +264,11 @@ def main():
     model = fsdp.FullyShardedDataParallel(
         model,
         device_id=local_rank,
+        auto_wrap_policy=functools.partial(
+            fsdp.wrap.size_based_auto_wrap_policy,
+            # Wrap every 1B parameters.
+            min_num_params=int(1e9),
+        ),
     )
     loss_func = torch.nn.CrossEntropyLoss()
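For context, here is a minimal, self-contained sketch (not part of the commit) of how the wrapped model might be used end to end. The toy model, tensor shapes, optimizer, and torchrun-style environment variables are assumptions for illustration only; the FullyShardedDataParallel call with the size-based auto-wrap policy mirrors the diff above.

import functools
import os

import torch
import torch.distributed as dist
import torch.distributed.fsdp as fsdp
import torch.distributed.fsdp.wrap  # makes `fsdp.wrap` resolvable


def main():
    # Assumes the script is launched with `torchrun`, which sets LOCAL_RANK.
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl")

    # Hypothetical toy model standing in for the example's real one.
    model = torch.nn.Sequential(
        torch.nn.Linear(1024, 4096),
        torch.nn.ReLU(),
        torch.nn.Linear(4096, 1024),
    ).cuda()

    # Same wrapping policy as in the commit: group submodules into FSDP
    # units once they accumulate at least `min_num_params` parameters.
    model = fsdp.FullyShardedDataParallel(
        model,
        device_id=local_rank,
        auto_wrap_policy=functools.partial(
            fsdp.wrap.size_based_auto_wrap_policy,
            # Wrap every 1B parameters.
            min_num_params=int(1e9),
        ),
    )

    loss_func = torch.nn.CrossEntropyLoss()
    # The optimizer must be created after wrapping so it references the
    # sharded parameters.
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    # One dummy training step on random data (shapes are assumptions).
    inputs = torch.randn(8, 1024, device="cuda")
    targets = torch.randint(0, 1024, (8,), device="cuda")
    loss = loss_func(model(inputs), targets)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    dist.destroy_process_group()


if __name__ == "__main__":
    main()

With size_based_auto_wrap_policy, submodules are collected into an FSDP unit whenever their combined parameter count reaches min_num_params, so the 1B threshold from the commit makes each shardable unit roughly one billion parameters' worth of layers.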