diff --git a/pytorch-fsdp-example/main.py b/pytorch-fsdp-example/main.py
index 60a24c84e1b310a2fb4a556cf55db74c0c135f35..6d6eeb3091c04907fac3f221918fbf72480536e9 100644
--- a/pytorch-fsdp-example/main.py
+++ b/pytorch-fsdp-example/main.py
@@ -231,11 +231,9 @@ def distribute_model(model, args):
     """Distribute the model across the different processes using Fully
     Sharded Data Parallelism (FSDP).
     """
-    local_rank = get_local_rank()
     # Set up FSDP or HSDP.
     if args.num_fsdp_replicas is None:
         fsdp_mesh_dims = (torch.distributed.get_world_size(),)
-        sharding_strategy = fsdp.ShardingStrategy.FULL_SHARD
     else:
         assert (
             torch.distributed.get_world_size() % args.num_fsdp_replicas
@@ -244,7 +242,6 @@ def distribute_model(model, args):
         fsdp_shards_per_replica = \
             torch.distributed.get_world_size() // args.num_fsdp_replicas
         fsdp_mesh_dims = (args.num_fsdp_replicas, fsdp_shards_per_replica)
-        sharding_strategy = fsdp.ShardingStrategy.HYBRID_SHARD
     fsdp_mesh = device_mesh.init_device_mesh("cuda", fsdp_mesh_dims)
 
     if args.model_type == 'resnet':
@@ -263,13 +260,11 @@ def distribute_model(model, args):
     else:
         raise ValueError(f'unknown model type "{args.model_type}"')
 
-    model = fsdp.FullyShardedDataParallel(
-        model,
-        device_id=local_rank,
-        device_mesh=fsdp_mesh,
-        sharding_strategy=sharding_strategy,
-        auto_wrap_policy=auto_wrap_policy,
-    )
+    fsdp_kwargs = dict(mesh=fsdp_mesh)
+    for module in model.modules():
+        if auto_wrap_policy(module):
+            fsdp.fully_shard(module, **fsdp_kwargs)
+    fsdp.fully_shard(model, **fsdp_kwargs)
 
     return model
 
diff --git a/pytorch-fsdp-example/requirements.txt b/pytorch-fsdp-example/requirements.txt
index 5c68afaf98da12508f0489eb3b63a74e674ea436..823de046e4487996bcda3460f0f82a646423ab20 100644
--- a/pytorch-fsdp-example/requirements.txt
+++ b/pytorch-fsdp-example/requirements.txt
@@ -1,3 +1,3 @@
-torch>=2.2,<3
+torch>=2.6,<3
 torchrun_jsc>=0.0.15
 torchvision>=0.13