diff --git a/pytorch-fsdp-example/main.py b/pytorch-fsdp-example/main.py
index d86328b8da04daf8691f79090c877afe1e709f63..d6ea289896993bdf05b6aa8475766cb808915ad1 100644
--- a/pytorch-fsdp-example/main.py
+++ b/pytorch-fsdp-example/main.py
@@ -5,11 +5,9 @@
 import time
 import torch
 from torch.distributed import checkpoint as dcp
-# The commented modules require PyTorch ≥2.2 and enable more modern
-# APIs.
-# from torch.distributed import device_mesh
+from torch.distributed import device_mesh
 from torch.distributed import fsdp
-# from torch.distributed.checkpoint import state_dict as dist_state_dict
+from torch.distributed.checkpoint import state_dict as dist_state_dict
 import torchvision
 
 
@@ -100,23 +98,14 @@ def save_model_singular(model, *args, **kwargs):
     other given arguments to `torch.save` to save the model, but only
     on the root process.
     """
-    save_policy = fsdp.FullStateDictConfig(
-        offload_to_cpu=True, rank0_only=True)
-    with fsdp.FullyShardedDataParallel.state_dict_type(
-            model,
-            fsdp.StateDictType.FULL_STATE_DICT,
-            save_policy,
-    ):
-        cpu_state = model.state_dict()
-    # For PyTorch versions ≥2.2:
-    # state_dict_options = dist_state_dict.StateDictOptions(
-    #     full_state_dict=True,
-    #     cpu_offload=True,
-    # )
-    # cpu_state = dist_state_dict.get_model_state_dict(
-    #     model,
-    #     options=state_dict_options,
-    # )
+    state_dict_options = dist_state_dict.StateDictOptions(
+        full_state_dict=True,
+        cpu_offload=True,
+    )
+    cpu_state = dist_state_dict.get_model_state_dict(
+        model,
+        options=state_dict_options,
+    )
 
     # We do *not* want to write to the same location with multiple
     # processes at the same time.
@@ -130,32 +119,19 @@
     distributed checkpoint means that the checkpoint will be split
     into individual files, one for each process.
     """
-    state_dict_config = fsdp.ShardedStateDictConfig(offload_to_cpu=False)
-    with fsdp.FullyShardedDataParallel.state_dict_type(
-            model,
-            fsdp.StateDictType.SHARDED_STATE_DICT,
-            state_dict_config,
-    ):
-        cp_state_dict = {'model': model.state_dict()}
-
-    dcp.save_state_dict(
+    state_dict_options = dist_state_dict.StateDictOptions(
+        cpu_offload=False,
+    )
+    model_state_dict = dist_state_dict.get_model_state_dict(
+        model,
+        options=state_dict_options,
+    )
+    cp_state_dict = {'model': model_state_dict}
+
+    dcp.save(
         cp_state_dict,
-        dcp.FileSystemWriter(save_dir),
+        storage_writer=dcp.FileSystemWriter(save_dir, overwrite=True),
     )
-    # For PyTorch versions ≥2.2:
-    # state_dict_options = dist_state_dict.StateDictOptions(
-    #     cpu_offload=False,
-    # )
-    # model_state_dict = dist_state_dict.get_model_state_dict(
-    #     model,
-    #     options=state_dict_options,
-    # )
-    # cp_state_dict = {'model': model_state_dict}
-    #
-    # dcp.save(
-    #     cp_state_dict,
-    #     storage_writer=dcp.FileSystemWriter(save_dir, overwrite=True),
-    # )
 
 
 def load_model_singular(model, *args, **kwargs):
@@ -171,34 +147,20 @@
     """Set the given model's state dictionary in-place from the given
     distributed checkpoint directory.
     """
-    state_dict_config = fsdp.ShardedStateDictConfig(offload_to_cpu=False)
-    with fsdp.FullyShardedDataParallel.state_dict_type(
-            model,
-            fsdp.StateDictType.SHARDED_STATE_DICT,
-            state_dict_config,
-    ):
-        cp_state_dict = {'model': model.state_dict()}
-
-    dcp.load_state_dict(
+    state_dict_options = dist_state_dict.StateDictOptions(
+        cpu_offload=False,
+    )
+    model_state_dict = dist_state_dict.get_model_state_dict(
+        model,
+        options=state_dict_options,
+    )
+    cp_state_dict = {'model': model_state_dict}
+
+    dcp.load(
         cp_state_dict,
-        dcp.FileSystemReader(load_dir),
+        storage_reader=dcp.FileSystemReader(load_dir),
     )
-    model.load_state_dict(cp_state_dict['model'])
-    # For PyTorch versions ≥2.2:
-    # state_dict_options = dist_state_dict.StateDictOptions(
-    #     cpu_offload=False,
-    # )
-    # model_state_dict = dist_state_dict.get_model_state_dict(
-    #     model,
-    #     options=state_dict_options,
-    # )
-    # cp_state_dict = {'model': model_state_dict}
-    #
-    # dcp.load(
-    #     cp_state_dict,
-    #     storage_reader=dcp.FileSystemReader(load_dir),
-    # )
-    # dist_state_dict.set_model_state_dict(model, cp_state_dict['model'])
+    dist_state_dict.set_model_state_dict(model, cp_state_dict['model'])
 
 
 def all_reduce_avg(tensor):
@@ -321,16 +283,14 @@
     train_dset, valid_dset, test_dset = prepare_datasets(args, device)
 
     model = build_model()
-    # For PyTorch versions ≥2.2, `mesh_1d` should be included to enable
-    # use of more modern implementations:
-    # mesh_1d = device_mesh.init_device_mesh(
-    #     "cuda",
-    #     (torch.distributed.get_world_size(),),
-    # )
+    mesh_1d = device_mesh.init_device_mesh(
+        "cuda",
+        (torch.distributed.get_world_size(),),
+    )
     model = fsdp.FullyShardedDataParallel(
         model,
         device_id=local_rank,
-        # device_mesh=mesh_1d,
+        device_mesh=mesh_1d,
         auto_wrap_policy=functools.partial(
             fsdp.wrap.size_based_auto_wrap_policy,
             # Wrap every 1B parameters.