Commit 99836847 authored by Jan Ebert

Add optional HSDP support

Documentation in the README is missing for now, but will be added in the
future.
parent 3fe25198
@@ -67,6 +67,17 @@ def parse_args():
         default=0,
         help='Random number generator initialization value.',
     )
+    parser.add_argument(
+        '--num-fsdp-replicas',
+        type=int,
+        help=(
+            'How many FSDP replicas to use for hybrid sharded data '
+            'parallelism (HSDP). The model will be sharded into '
+            '`world_size / num_fsdp_replicas` partitions per replica. '
+            'Gradients will be all-reduced across the replicas. '
+            'If not given, use standard FSDP.'
+        ),
+    )
     args = parser.parse_args()
     return args
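The help text above fixes the mesh geometry: with `num_fsdp_replicas` replicas, each replica shards the model over `world_size / num_fsdp_replicas` ranks. A minimal sketch of that arithmetic, using made-up numbers that are not part of the commit:

# Illustration only: the world size and replica count below are hypothetical.
world_size = 16            # e.g. 4 nodes with 4 GPUs each
num_fsdp_replicas = 4      # passed via --num-fsdp-replicas 4

assert world_size % num_fsdp_replicas == 0
fsdp_shards_per_replica = world_size // num_fsdp_replicas  # 4

# Device mesh shape: (replicas, shards per replica) = (4, 4).
# Parameters are sharded over 4 ranks within each replica;
# gradients are all-reduced across the 4 replicas.
fsdp_mesh_dims = (num_fsdp_replicas, fsdp_shards_per_replica)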
@@ -283,14 +294,25 @@ def main():
     train_dset, valid_dset, test_dset = prepare_datasets(args, device)
     model = build_model()
-    mesh_1d = device_mesh.init_device_mesh(
-        "cuda",
-        (torch.distributed.get_world_size(),),
-    )
+    # Set up FSDP or HSDP.
+    if args.num_fsdp_replicas is None:
+        fsdp_mesh_dims = (torch.distributed.get_world_size(),)
+        sharding_strategy = fsdp.ShardingStrategy.FULL_SHARD
+    else:
+        assert (
+            torch.distributed.get_world_size() % args.num_fsdp_replicas
+            == 0
+        ), 'world size must be divisible by number of FSDP replicas'
+        fsdp_shards_per_replica = \
+            torch.distributed.get_world_size() // args.num_fsdp_replicas
+        fsdp_mesh_dims = (args.num_fsdp_replicas, fsdp_shards_per_replica)
+        sharding_strategy = fsdp.ShardingStrategy.HYBRID_SHARD
+    fsdp_mesh = device_mesh.init_device_mesh("cuda", fsdp_mesh_dims)
     model = fsdp.FullyShardedDataParallel(
         model,
         device_id=local_rank,
-        device_mesh=mesh_1d,
+        device_mesh=fsdp_mesh,
+        sharding_strategy=sharding_strategy,
         auto_wrap_policy=functools.partial(
             fsdp.wrap.size_based_auto_wrap_policy,
             # Wrap every 1B parameters.
...
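Until the README is updated, the following standalone sketch shows how the new code path could be exercised end to end. It is an assumption-laden illustration, not part of the commit: the toy `torch.nn.Linear` model, the file name `hsdp_sketch.py`, and the hard-coded replica count stand in for the repository's real model and argument parsing.

# hsdp_sketch.py -- illustrative only; model, file name, and replica count are assumed.
import os

import torch
import torch.distributed as dist
from torch.distributed import device_mesh, fsdp


def main():
    dist.init_process_group('nccl')
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)

    num_fsdp_replicas = 2  # would normally come from --num-fsdp-replicas
    world_size = dist.get_world_size()
    assert world_size % num_fsdp_replicas == 0
    fsdp_mesh = device_mesh.init_device_mesh(
        'cuda',
        (num_fsdp_replicas, world_size // num_fsdp_replicas),
    )

    model = torch.nn.Linear(1024, 1024).cuda()  # toy stand-in model
    model = fsdp.FullyShardedDataParallel(
        model,
        device_id=local_rank,
        device_mesh=fsdp_mesh,
        sharding_strategy=fsdp.ShardingStrategy.HYBRID_SHARD,
    )

    # Each replica shards the parameters over world_size / num_fsdp_replicas
    # ranks; gradients are all-reduced across the replicas after backward.
    out = model(torch.randn(8, 1024, device='cuda'))
    out.sum().backward()

    dist.destroy_process_group()


if __name__ == '__main__':
    main()

Launched with, e.g., `torchrun --nproc_per_node=4 hsdp_sketch.py` on a single 4-GPU node, this yields two replicas of two shards each.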