From 7e183b8cb9f9e236c2531a0ea3650372f6860902 Mon Sep 17 00:00:00 2001
From: janEbert <janpublicebert@posteo.net>
Date: Mon, 2 Jun 2025 16:30:01 +0200
Subject: [PATCH] Use FSDP2

Migrate from the FSDP1 `FullyShardedDataParallel` wrapper class to the
composable FSDP2 `fully_shard` API, which shards modules in place
instead of wrapping them. The sharding strategy is now implied by the
device mesh (a 1D mesh gives full sharding, a 2D mesh hybrid sharding)
and the device is inferred, so the explicit `sharding_strategy` and
`device_id` arguments go away. `fully_shard` has been public API in
`torch.distributed.fsdp` since PyTorch 2.6, hence the bumped torch
requirement.
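
For reference, a minimal sketch of the FSDP2 idiom used below, with a
toy model and setup assumed purely for illustration (this is not code
from this repo). Launch with e.g. `torchrun --nproc_per_node=4`:

    import os

    import torch
    import torch.distributed as dist
    from torch.distributed import device_mesh, fsdp

    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    model = torch.nn.Sequential(
        torch.nn.Linear(16, 16), torch.nn.ReLU(), torch.nn.Linear(16, 4)
    ).cuda()

    # A 1D mesh gives full sharding; a 2D mesh would give HSDP.
    mesh = device_mesh.init_device_mesh("cuda", (dist.get_world_size(),))

    # Shard submodules first, then the root module, which takes over the
    # remaining parameters; `fully_shard` modifies the modules in place
    # rather than wrapping them.
    for layer in model:
        if isinstance(layer, torch.nn.Linear):
            fsdp.fully_shard(layer, mesh=mesh)
    fsdp.fully_shard(model, mesh=mesh)

    loss = model(torch.randn(8, 16, device="cuda")).sum()
    loss.backward()
    dist.destroy_process_group()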
---
 pytorch-fsdp-example/main.py          | 18 ++++++++----------
 pytorch-fsdp-example/requirements.txt |  2 +-
 2 files changed, 9 insertions(+), 11 deletions(-)

diff --git a/pytorch-fsdp-example/main.py b/pytorch-fsdp-example/main.py
index 60a24c8..6d6eeb3 100644
--- a/pytorch-fsdp-example/main.py
+++ b/pytorch-fsdp-example/main.py
@@ -231,11 +231,9 @@ def distribute_model(model, args):
     """Distribute the model across the different processes using Fully
     Sharded Data Parallelism (FSDP).
     """
-    local_rank = get_local_rank()
     # Set up FSDP or HSDP.
     if args.num_fsdp_replicas is None:
         fsdp_mesh_dims = (torch.distributed.get_world_size(),)
-        sharding_strategy = fsdp.ShardingStrategy.FULL_SHARD
     else:
         assert (
             torch.distributed.get_world_size() % args.num_fsdp_replicas
@@ -244,7 +242,6 @@ def distribute_model(model, args):
         fsdp_shards_per_replica = \
             torch.distributed.get_world_size() // args.num_fsdp_replicas
         fsdp_mesh_dims = (args.num_fsdp_replicas, fsdp_shards_per_replica)
-        sharding_strategy = fsdp.ShardingStrategy.HYBRID_SHARD
     fsdp_mesh = device_mesh.init_device_mesh("cuda", fsdp_mesh_dims)
 
     if args.model_type == 'resnet':
@@ -263,13 +260,14 @@ def distribute_model(model, args):
     else:
         raise ValueError(f'unknown model type "{args.model_type}"')
 
-    model = fsdp.FullyShardedDataParallel(
-        model,
-        device_id=local_rank,
-        device_mesh=fsdp_mesh,
-        sharding_strategy=sharding_strategy,
-        auto_wrap_policy=auto_wrap_policy,
-    )
+    fsdp_kwargs = dict(mesh=fsdp_mesh)
+    # Shard the policy-selected submodules first, then the root module,
+    # which takes over the remaining parameters. Skip the root in the
+    # loop so that `fully_shard` is not applied to it twice.
+    for module in model.modules():
+        if module is not model and auto_wrap_policy(module):
+            fsdp.fully_shard(module, **fsdp_kwargs)
+    fsdp.fully_shard(model, **fsdp_kwargs)
     return model
 
 
diff --git a/pytorch-fsdp-example/requirements.txt b/pytorch-fsdp-example/requirements.txt
index 5c68afa..823de04 100644
--- a/pytorch-fsdp-example/requirements.txt
+++ b/pytorch-fsdp-example/requirements.txt
@@ -1,3 +1,3 @@
-torch>=2.2,<3
+torch>=2.6,<3
 torchrun_jsc>=0.0.15
 torchvision>=0.13
-- 
GitLab