diff --git a/pytorch-fsdp-example/main.py b/pytorch-fsdp-example/main.py
index a5cc9a62adfaf996378101c4118ee57ab039e10a..f8ebfbfcb718154686749898fca193e13de32f6b 100644
--- a/pytorch-fsdp-example/main.py
+++ b/pytorch-fsdp-example/main.py
@@ -264,6 +264,11 @@ def main():
     model = fsdp.FullyShardedDataParallel(
         model,
         device_id=local_rank,
+        auto_wrap_policy=functools.partial(
+            fsdp.wrap.size_based_auto_wrap_policy,
+            # Wrap every 1B parameters.
+            min_num_params=int(1e9),
+        ),
     )
 
     loss_func = torch.nn.CrossEntropyLoss()
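
For context: the new `auto_wrap_policy` makes FSDP recursively wrap submodules into their own sharding units once their unwrapped parameter count crosses the threshold, instead of treating the whole model as one flat unit; with the 1e9 threshold in this change, only sufficiently large blocks become separate FSDP units. The sketch below is a minimal standalone illustration of the same wrapping behaviour, not part of this change: the `ToyModel`, the much smaller `min_num_params` threshold, and the single-process NCCL setup are illustrative assumptions so the policy visibly triggers on a small model.

```python
import functools
import os

import torch
import torch.distributed as dist
import torch.distributed.fsdp as fsdp
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy


class ToyModel(torch.nn.Module):
    """Illustrative stand-in for the model wrapped in main.py."""

    def __init__(self):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(1024, 4096),
            torch.nn.ReLU(),
            torch.nn.Linear(4096, 1024),
        )

    def forward(self, x):
        return self.layers(x)


def demo(local_rank: int) -> None:
    torch.cuda.set_device(local_rank)
    model = ToyModel().cuda()

    # Same kind of policy as in the diff, but with a much smaller threshold
    # so the toy model's Linear layers actually exceed it and get wrapped.
    policy = functools.partial(
        size_based_auto_wrap_policy,
        min_num_params=int(1e6),
    )
    model = fsdp.FullyShardedDataParallel(
        model,
        device_id=local_rank,
        auto_wrap_policy=policy,
    )
    # Printing the model shows nested FullyShardedDataParallel units
    # wherever the policy triggered.
    print(model)


if __name__ == "__main__":
    # Single-process "world" just to exercise the wrapping logic; a real run
    # would launch via torchrun and read RANK/WORLD_SIZE from the environment.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=0, world_size=1)
    demo(local_rank=0)
    dist.destroy_process_group()
```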