From 7e183b8cb9f9e236c2531a0ea3650372f6860902 Mon Sep 17 00:00:00 2001
From: janEbert <janpublicebert@posteo.net>
Date: Mon, 2 Jun 2025 16:30:01 +0200
Subject: [PATCH] Use FSDP2 Instead of FSDP1.

---
 pytorch-fsdp-example/main.py          | 15 +++++----------
 pytorch-fsdp-example/requirements.txt |  2 +-
 2 files changed, 6 insertions(+), 11 deletions(-)

diff --git a/pytorch-fsdp-example/main.py b/pytorch-fsdp-example/main.py
index 60a24c8..6d6eeb3 100644
--- a/pytorch-fsdp-example/main.py
+++ b/pytorch-fsdp-example/main.py
@@ -231,11 +231,9 @@ def distribute_model(model, args):
     """Distribute the model across the different processes using Fully
     Sharded Data Parallelism (FSDP).
     """
-    local_rank = get_local_rank()
     # Set up FSDP or HSDP.
     if args.num_fsdp_replicas is None:
         fsdp_mesh_dims = (torch.distributed.get_world_size(),)
-        sharding_strategy = fsdp.ShardingStrategy.FULL_SHARD
     else:
         assert (
             torch.distributed.get_world_size() % args.num_fsdp_replicas
@@ -244,7 +242,6 @@ def distribute_model(model, args):
         fsdp_shards_per_replica = \
             torch.distributed.get_world_size() // args.num_fsdp_replicas
         fsdp_mesh_dims = (args.num_fsdp_replicas, fsdp_shards_per_replica)
-        sharding_strategy = fsdp.ShardingStrategy.HYBRID_SHARD
     fsdp_mesh = device_mesh.init_device_mesh("cuda", fsdp_mesh_dims)
 
     if args.model_type == 'resnet':
@@ -263,13 +260,11 @@ def distribute_model(model, args):
     else:
         raise ValueError(f'unknown model type "{args.model_type}"')
 
-    model = fsdp.FullyShardedDataParallel(
-        model,
-        device_id=local_rank,
-        device_mesh=fsdp_mesh,
-        sharding_strategy=sharding_strategy,
-        auto_wrap_policy=auto_wrap_policy,
-    )
+    fsdp_kwargs = dict(mesh=fsdp_mesh)
+    for module in model.modules():
+        if auto_wrap_policy(module):
+            fsdp.fully_shard(module, **fsdp_kwargs)
+    fsdp.fully_shard(model, **fsdp_kwargs)
 
     return model
 
diff --git a/pytorch-fsdp-example/requirements.txt b/pytorch-fsdp-example/requirements.txt
index 5c68afa..823de04 100644
--- a/pytorch-fsdp-example/requirements.txt
+++ b/pytorch-fsdp-example/requirements.txt
@@ -1,3 +1,3 @@
-torch>=2.2,<3
+torch>=2.6,<3
 torchrun_jsc>=0.0.15
 torchvision>=0.13
--
GitLab
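
Note (not part of the patch): a minimal standalone sketch of the FSDP2 pattern the
diff adopts. `fully_shard` shards a module in place over a device mesh, where a 1D
mesh gives plain FSDP and a 2D (replicas, shards) mesh gives HSDP, which is why the
explicit `sharding_strategy` from FSDP1 disappears; inner blocks are sharded first,
then the root module picks up the remaining parameters. The `block_cls` predicate
below is an illustrative stand-in for the example's `auto_wrap_policy`, not code
from the repository:

    import torch
    from torch import nn
    from torch.distributed import device_mesh, fsdp

    def shard_model(model: nn.Module, block_cls: type) -> nn.Module:
        # A 1D mesh over all ranks -> plain FSDP (FULL_SHARD equivalent);
        # a 2D (num_replicas, shards_per_replica) mesh would give HSDP.
        mesh = device_mesh.init_device_mesh(
            "cuda", (torch.distributed.get_world_size(),)
        )
        # FSDP2 style: shard each target block as its own unit first,
        # then the root module, which gathers the leftover parameters.
        # `fully_shard` mutates the module (parameters become DTensors).
        for module in model.modules():
            if isinstance(module, block_cls):  # stand-in for auto_wrap_policy
                fsdp.fully_shard(module, mesh=mesh)
        fsdp.fully_shard(model, mesh=mesh)
        return model

This also explains the requirements.txt bump: `fully_shard` is only available under
the public `torch.distributed.fsdp` namespace as of torch 2.6.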