Commit 2090b798 authored by Jan Ebert

Fix wrap policy usage

We cannot simply re-use the existing wrap policies (`size_based_auto_wrap_policy`,
`ModuleWrapPolicy`) as plain predicates because they rely on `_recursive_wrap`'s
traversal logic.
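
For context, a minimal sketch (assuming a recent PyTorch; not part of this commit) of why the existing helper cannot serve as a plain `module -> bool` predicate: `fsdp.wrap.size_based_auto_wrap_policy` expects the extra arguments that `_recursive_wrap` supplies during its bottom-up traversal.

import functools

from torch import nn
from torch.distributed import fsdp

# Illustration only: the existing helper is written for `_recursive_wrap`,
# which passes `recurse` and `nonwrapped_numel` on every visit;
# `functools.partial` only fixes the threshold.
policy = functools.partial(
    fsdp.wrap.size_based_auto_wrap_policy,
    min_num_params=int(1e9),
)

block = nn.Linear(4, 4)
# `policy(block)` raises a TypeError because `recurse` and `nonwrapped_numel`
# are missing, so the helper cannot be called directly as a predicate in a
# manual loop over `model.modules()`.
print(policy(
    block,
    recurse=False,
    nonwrapped_numel=sum(p.numel() for p in block.parameters()),
))  # False: a 4x4 linear layer is far below the 1B-parameter threshold.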
parent 7da56b3e
@@ -247,22 +247,23 @@ def distribute_model(model, args):
     if args.model_type == 'resnet':
         # We could also use the `ModuleWrapPolicy` here, but this way we
         # show a method that works with arbitrary models.
-        auto_wrap_policy = functools.partial(
-            fsdp.wrap.size_based_auto_wrap_policy,
-            # Wrap every 1B parameters.
-            min_num_params=int(1e9),
-        )
+        def wrap_policy(module):
+            num_params = sum(p.numel() for p in module.parameters())
+            # Wrap every 1B parameters.
+            return num_params >= int(1e9)
     elif args.model_type == 'vit':
         # Each Transformer block becomes one FSDP unit.
-        auto_wrap_policy = fsdp.wrap.ModuleWrapPolicy({
-            torchvision.models.vision_transformer.EncoderBlock,
-        })
+        def wrap_policy(module):
+            return isinstance(
+                module,
+                torchvision.models.vision_transformer.EncoderBlock,
+            )
     else:
         raise ValueError(f'unknown model type "{args.model_type}"')
     fsdp_kwargs = dict(mesh=fsdp_mesh)
     for module in model.modules():
-        if auto_wrap_policy(module):
+        if wrap_policy(module):
             fsdp.fully_shard(module, **fsdp_kwargs)
     fsdp.fully_shard(model, **fsdp_kwargs)
     return model
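
For reference, a self-contained sketch of the same pattern end to end for the ViT branch; the process-group and mesh setup (e.g. launching via `torchrun`) is assumed here and not part of this repository.

import torch
import torchvision
from torch import distributed as dist
from torch.distributed import fsdp
from torch.distributed.device_mesh import init_device_mesh

# Assumed setup: one process per device, launched e.g. via `torchrun`.
dist.init_process_group()
device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
fsdp_mesh = init_device_mesh(device_type, (dist.get_world_size(),))

model = torchvision.models.vit_b_16()
# Mirror the manual loop above: shard every Transformer block first, then
# the root module, so the remaining parameters land in the root FSDP unit.
for module in model.modules():
    if isinstance(module, torchvision.models.vision_transformer.EncoderBlock):
        fsdp.fully_shard(module, mesh=fsdp_mesh)
fsdp.fully_shard(model, mesh=fsdp_mesh)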