diff --git a/pytorch-fsdp-example/main.py b/pytorch-fsdp-example/main.py
index 44c7a2445a76a3b47e6a74fb48a385b582782c3e..d86328b8da04daf8691f79090c877afe1e709f63 100644
--- a/pytorch-fsdp-example/main.py
+++ b/pytorch-fsdp-example/main.py
@@ -5,7 +5,11 @@ import time
 
 import torch
 from torch.distributed import checkpoint as dcp
+# The commented modules require PyTorch ≥2.2 and enable more modern
+# APIs.
+# from torch.distributed import device_mesh
 from torch.distributed import fsdp
+# from torch.distributed.checkpoint import state_dict as dist_state_dict
 import torchvision
 
 
@@ -104,6 +108,16 @@ def save_model_singular(model, *args, **kwargs):
             save_policy,
     ):
         cpu_state = model.state_dict()
+        # For PyTorch versions ≥2.2:
+        # state_dict_options = dist_state_dict.StateDictOptions(
+        #     full_state_dict=True,
+        #     cpu_offload=True,
+        # )
+        # cpu_state = dist_state_dict.get_model_state_dict(
+        #     model,
+        #     options=state_dict_options,
+        # )
+
     # We do *not* want to write to the same location with multiple
     # processes at the same time.
     if is_root_process():
@@ -123,10 +137,25 @@ def save_model(model, save_dir):
             state_dict_config,
     ):
         cp_state_dict = {'model': model.state_dict()}
+
         dcp.save_state_dict(
             cp_state_dict,
             dcp.FileSystemWriter(save_dir),
         )
+        # For PyTorch versions ≥2.2:
+        # state_dict_options = dist_state_dict.StateDictOptions(
+        #     cpu_offload=False,
+        # )
+        # model_state_dict = dist_state_dict.get_model_state_dict(
+        #     model,
+        #     options=state_dict_options,
+        # )
+        # cp_state_dict = {'model': model_state_dict}
+        #
+        # dcp.save(
+        #     cp_state_dict,
+        #     storage_writer=dcp.FileSystemWriter(save_dir, overwrite=True),
+        # )
 
 
 def load_model_singular(model, *args, **kwargs):
@@ -149,11 +178,27 @@ def load_model(model, load_dir):
             state_dict_config,
     ):
         cp_state_dict = {'model': model.state_dict()}
+
         dcp.load_state_dict(
             cp_state_dict,
             dcp.FileSystemReader(load_dir),
         )
         model.load_state_dict(cp_state_dict['model'])
+        # For PyTorch versions ≥2.2:
+        # state_dict_options = dist_state_dict.StateDictOptions(
+        #     cpu_offload=False,
+        # )
+        # model_state_dict = dist_state_dict.get_model_state_dict(
+        #     model,
+        #     options=state_dict_options,
+        # )
+        # cp_state_dict = {'model': model_state_dict}
+        #
+        # dcp.load(
+        #     cp_state_dict,
+        #     storage_reader=dcp.FileSystemReader(load_dir),
+        # )
+        # dist_state_dict.set_model_state_dict(model, cp_state_dict['model'])
 
 
 def all_reduce_avg(tensor):
@@ -276,9 +321,16 @@ def main():
     train_dset, valid_dset, test_dset = prepare_datasets(args, device)
 
     model = build_model()
+    # For PyTorch versions ≥2.2, `mesh_1d` should be included to enable
+    # use of more modern implementations:
+    # mesh_1d = device_mesh.init_device_mesh(
+    #     "cuda",
+    #     (torch.distributed.get_world_size(),),
+    # )
     model = fsdp.FullyShardedDataParallel(
         model,
         device_id=local_rank,
+        # device_mesh=mesh_1d,
         auto_wrap_policy=functools.partial(
             fsdp.wrap.size_based_auto_wrap_policy,
             # Wrap every 1B parameters.
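
For reference, the commented-out PyTorch ≥2.2 path in this patch can be assembled into standalone helpers like the sketch below. This is a sketch only, not part of the patch itself: the helper names save_model_modern and load_model_modern are illustrative, and it assumes torch.distributed is already initialized and that model is FSDP-wrapped.

from torch.distributed import checkpoint as dcp
from torch.distributed.checkpoint import state_dict as dist_state_dict


def save_model_modern(model, save_dir):
    # Gather a sharded model state dict, keeping tensors on-device,
    # then write it through the modern `dcp.save` entry point.
    state_dict_options = dist_state_dict.StateDictOptions(
        cpu_offload=False,
    )
    model_state_dict = dist_state_dict.get_model_state_dict(
        model,
        options=state_dict_options,
    )
    dcp.save(
        {'model': model_state_dict},
        storage_writer=dcp.FileSystemWriter(save_dir, overwrite=True),
    )


def load_model_modern(model, load_dir):
    # Load the checkpoint shards into the state dict in-place, then
    # push the result back into the model.
    state_dict_options = dist_state_dict.StateDictOptions(
        cpu_offload=False,
    )
    cp_state_dict = {
        'model': dist_state_dict.get_model_state_dict(
            model,
            options=state_dict_options,
        ),
    }
    dcp.load(
        cp_state_dict,
        storage_reader=dcp.FileSystemReader(load_dir),
    )
    dist_state_dict.set_model_state_dict(model, cp_state_dict['model'])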