Commit 7e183b8c authored by Jan Ebert

Use FSDP2

Instead of FSDP1.
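For context, FSDP2 replaces the FSDP1 `FullyShardedDataParallel` wrapper class with the composable `fully_shard` function. The sharding strategy is now implied by the shape of the device mesh (a 1-D mesh gives full sharding, a 2-D mesh gives hybrid sharding), so the explicit `ShardingStrategy` and `device_id` arguments from the old code are no longer needed. Below is a minimal, self-contained sketch of the same wrapping pattern; the toy model, the `LOCAL_RANK` handling, and the `isinstance(module, nn.Linear)` check are placeholder assumptions standing in for this repository's own model setup and `auto_wrap_policy`, not its actual code.

# Minimal FSDP2 sketch (PyTorch >= 2.6) mirroring the wrapping pattern in this
# commit. Launch with torchrun, one process per GPU. The toy model and the
# nn.Linear check are placeholders, not this repository's code.
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard


def shard_model(model, num_fsdp_replicas=None):
    """Shard `model` with FSDP2's `fully_shard`.

    A 1-D mesh gives fully sharded data parallelism; a 2-D
    (replicas, shards) mesh gives hybrid sharding, so no explicit
    `ShardingStrategy` is passed anymore.
    """
    world_size = dist.get_world_size()
    if num_fsdp_replicas is None:
        mesh_dims = (world_size,)
    else:
        assert world_size % num_fsdp_replicas == 0
        mesh_dims = (num_fsdp_replicas, world_size // num_fsdp_replicas)
    mesh = init_device_mesh("cuda", mesh_dims)

    # Shard the submodules first, then the root module for any remaining
    # parameters, just like the loop over `model.modules()` in the diff.
    for module in model.modules():
        if isinstance(module, nn.Linear):  # placeholder for auto_wrap_policy
            fully_shard(module, mesh=mesh)
    fully_shard(model, mesh=mesh)
    return model


if __name__ == "__main__":
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))  # set by torchrun
    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16)).cuda()
    model = shard_model(model)
    dist.destroy_process_group()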
parent e90d4094
@@ -231,11 +231,9 @@ def distribute_model(model, args):
     """Distribute the model across the different processes using Fully
     Sharded Data Parallelism (FSDP).
     """
-    local_rank = get_local_rank()
     # Set up FSDP or HSDP.
     if args.num_fsdp_replicas is None:
         fsdp_mesh_dims = (torch.distributed.get_world_size(),)
-        sharding_strategy = fsdp.ShardingStrategy.FULL_SHARD
     else:
         assert (
             torch.distributed.get_world_size() % args.num_fsdp_replicas
@@ -244,7 +242,6 @@ def distribute_model(model, args):
         fsdp_shards_per_replica = \
             torch.distributed.get_world_size() // args.num_fsdp_replicas
         fsdp_mesh_dims = (args.num_fsdp_replicas, fsdp_shards_per_replica)
-        sharding_strategy = fsdp.ShardingStrategy.HYBRID_SHARD
     fsdp_mesh = device_mesh.init_device_mesh("cuda", fsdp_mesh_dims)
     if args.model_type == 'resnet':
@@ -263,13 +260,11 @@ def distribute_model(model, args):
     else:
         raise ValueError(f'unknown model type "{args.model_type}"')
-    model = fsdp.FullyShardedDataParallel(
-        model,
-        device_id=local_rank,
-        device_mesh=fsdp_mesh,
-        sharding_strategy=sharding_strategy,
-        auto_wrap_policy=auto_wrap_policy,
-    )
+    fsdp_kwargs = dict(mesh=fsdp_mesh)
+    for module in model.modules():
+        if auto_wrap_policy(module):
+            fsdp.fully_shard(module, **fsdp_kwargs)
+    fsdp.fully_shard(model, **fsdp_kwargs)
     return model
-torch>=2.2,<3
+torch>=2.6,<3
 torchrun_jsc>=0.0.15
 torchvision>=0.13