diff --git a/pytorch-fsdp-example/main.py b/pytorch-fsdp-example/main.py
index c72dd78bbb7569e48bf7e0c34f16ad89e226022e..6a24d934ec6bb312b225d51b868fe9506a5df605 100644
--- a/pytorch-fsdp-example/main.py
+++ b/pytorch-fsdp-example/main.py
@@ -67,6 +67,17 @@ def parse_args():
         default=0,
         help='Random number generator initialization value.',
     )
+    parser.add_argument(
+        '--num-fsdp-replicas',
+        type=int,
+        help=(
+            'How many FSDP replicas to use for hybrid sharded data '
+            'parallelism (HSDP). The model will be sharded into '
+            '`world_size / num_fsdp_replicas` partitions per replica. '
+            'Gradients will be all-reduced across the replicas. '
+            'If not given, use standard FSDP.'
+        ),
+    )
     args = parser.parse_args()
     return args
 
@@ -283,14 +294,25 @@ def main():
     train_dset, valid_dset, test_dset = prepare_datasets(args, device)
 
     model = build_model()
-    mesh_1d = device_mesh.init_device_mesh(
-        "cuda",
-        (torch.distributed.get_world_size(),),
-    )
+    # Set up FSDP or HSDP.
+    if args.num_fsdp_replicas is None:
+        fsdp_mesh_dims = (torch.distributed.get_world_size(),)
+        sharding_strategy = fsdp.ShardingStrategy.FULL_SHARD
+    else:
+        assert (
+            torch.distributed.get_world_size() % args.num_fsdp_replicas
+            == 0
+        ), 'world size must be divisible by number of FSDP replicas'
+        fsdp_shards_per_replica = \
+            torch.distributed.get_world_size() // args.num_fsdp_replicas
+        fsdp_mesh_dims = (args.num_fsdp_replicas, fsdp_shards_per_replica)
+        sharding_strategy = fsdp.ShardingStrategy.HYBRID_SHARD
+    fsdp_mesh = device_mesh.init_device_mesh("cuda", fsdp_mesh_dims)
     model = fsdp.FullyShardedDataParallel(
         model,
         device_id=local_rank,
-        device_mesh=mesh_1d,
+        device_mesh=fsdp_mesh,
+        sharding_strategy=sharding_strategy,
         auto_wrap_policy=functools.partial(
             fsdp.wrap.size_based_auto_wrap_policy,
             # Wrap every 1B parameters.
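
For context, the new flag maps the world size onto a two-dimensional device mesh: on, say, 16 GPUs, `--num-fsdp-replicas 4` yields a `(4, 4)` mesh, i.e. 4 replicas that each shard the model across 4 GPUs and all-reduce gradients with one another. Below is a minimal standalone sketch of the same HSDP setup using the public PyTorch APIs the patch relies on; the `wrap_hsdp` helper name and the use of the `LOCAL_RANK` environment variable are illustrative assumptions, not part of the patch.

```python
# Minimal HSDP sketch, assuming torch.distributed.init_process_group() has
# already been called (e.g. by a torchrun-launched training script) and that
# the world size is divisible by num_replicas.
import os

import torch
from torch import nn
from torch.distributed import device_mesh, fsdp


def wrap_hsdp(model: nn.Module, num_replicas: int) -> fsdp.FullyShardedDataParallel:
    world_size = torch.distributed.get_world_size()
    assert world_size % num_replicas == 0, \
        'world size must be divisible by number of FSDP replicas'
    # Dim 0: replicas whose gradients are all-reduced with one another;
    # dim 1: shards of the model parameters within each replica.
    mesh = device_mesh.init_device_mesh(
        "cuda",
        (num_replicas, world_size // num_replicas),
    )
    return fsdp.FullyShardedDataParallel(
        model,
        device_id=int(os.environ['LOCAL_RANK']),
        device_mesh=mesh,
        sharding_strategy=fsdp.ShardingStrategy.HYBRID_SHARD,
    )
```

With `ShardingStrategy.HYBRID_SHARD`, FSDP treats the first mesh dimension as the replication (gradient all-reduce) dimension and the second as the sharding dimension, which is why the patch builds the mesh as `(num_fsdp_replicas, fsdp_shards_per_replica)`.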