From ceb61d4f9c2644123d9f6f73e1a94a74151f147d Mon Sep 17 00:00:00 2001
From: janEbert <janpublicebert@posteo.net>
Date: Thu, 14 Nov 2024 15:09:42 +0100
Subject: [PATCH] Add code for later PyTorch versions

We could put some version parsing and appropriate if-guards into the
code, but I'm afraid of the increase in complexity.
---
 pytorch-fsdp-example/main.py | 52 ++++++++++++++++++++++++++++++++++++
 1 file changed, 52 insertions(+)

diff --git a/pytorch-fsdp-example/main.py b/pytorch-fsdp-example/main.py
index 44c7a24..d86328b 100644
--- a/pytorch-fsdp-example/main.py
+++ b/pytorch-fsdp-example/main.py
@@ -5,7 +5,11 @@ import time

 import torch
 from torch.distributed import checkpoint as dcp
+# The commented modules require PyTorch ≥2.2 and enable more modern
+# APIs.
+# from torch.distributed import device_mesh
 from torch.distributed import fsdp
+# from torch.distributed.checkpoint import state_dict as dist_state_dict

 import torchvision

@@ -104,6 +108,16 @@ def save_model_singular(model, *args, **kwargs):
             save_policy,
     ):
         cpu_state = model.state_dict()
+        # For PyTorch versions ≥2.2:
+        # state_dict_options = dist_state_dict.StateDictOptions(
+        #     full_state_dict=True,
+        #     cpu_offload=True,
+        # )
+        # cpu_state = dist_state_dict.get_model_state_dict(
+        #     model,
+        #     options=state_dict_options,
+        # )
+
     # We do *not* want to write to the same location with multiple
     # processes at the same time.
     if is_root_process():
@@ -123,10 +137,25 @@ def save_model(model, save_dir):
             state_dict_config,
     ):
         cp_state_dict = {'model': model.state_dict()}
+
         dcp.save_state_dict(
             cp_state_dict,
             dcp.FileSystemWriter(save_dir),
         )
+        # For PyTorch versions ≥2.2:
+        # state_dict_options = dist_state_dict.StateDictOptions(
+        #     cpu_offload=False,
+        # )
+        # model_state_dict = dist_state_dict.get_model_state_dict(
+        #     model,
+        #     options=state_dict_options,
+        # )
+        # cp_state_dict = {'model': model_state_dict}
+        #
+        # dcp.save(
+        #     cp_state_dict,
+        #     storage_writer=dcp.FileSystemWriter(save_dir, overwrite=True),
+        # )


 def load_model_singular(model, *args, **kwargs):
@@ -149,11 +178,27 @@
             state_dict_config,
     ):
         cp_state_dict = {'model': model.state_dict()}
+
         dcp.load_state_dict(
             cp_state_dict,
             dcp.FileSystemReader(load_dir),
         )
         model.load_state_dict(cp_state_dict['model'])
+        # For PyTorch versions ≥2.2:
+        # state_dict_options = dist_state_dict.StateDictOptions(
+        #     cpu_offload=False,
+        # )
+        # model_state_dict = dist_state_dict.get_model_state_dict(
+        #     model,
+        #     options=state_dict_options,
+        # )
+        # cp_state_dict = {'model': model_state_dict}
+        #
+        # dcp.load(
+        #     cp_state_dict,
+        #     storage_reader=dcp.FileSystemReader(load_dir),
+        # )
+        # dist_state_dict.set_model_state_dict(model, cp_state_dict['model'])


 def all_reduce_avg(tensor):
@@ -276,9 +321,16 @@ def main():
     train_dset, valid_dset, test_dset = prepare_datasets(args, device)

     model = build_model()
+    # For PyTorch versions ≥2.2, `mesh_1d` should be included to enable
+    # use of more modern implementations:
+    # mesh_1d = device_mesh.init_device_mesh(
+    #     "cuda",
+    #     (torch.distributed.get_world_size(),),
+    # )
     model = fsdp.FullyShardedDataParallel(
         model,
         device_id=local_rank,
+        # device_mesh=mesh_1d,
         auto_wrap_policy=functools.partial(
             fsdp.wrap.size_based_auto_wrap_policy,
             # Wrap every 1B parameters.
--
GitLab
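
Note (not part of the patch): the commit message mentions version parsing
and if-guards as the alternative it decides against. Below is a minimal
sketch of what that could look like, assuming only the public
`torch.__version__` string and the FSDP/DCP APIs already referenced in
main.py; the helper names `torch_version_at_least` and
`full_cpu_state_dict` are made up for illustration and do not exist in
the example code.

import torch
from torch.distributed import fsdp


def torch_version_at_least(major, minor):
    # `torch.__version__` looks like "2.1.0" or "2.4.0a0+git...";
    # only the leading "major.minor" part matters here.
    found_major, found_minor = (
        int(part) for part in torch.__version__.split('.')[:2])
    return (found_major, found_minor) >= (major, minor)


HAS_MODERN_DCP = torch_version_at_least(2, 2)

if HAS_MODERN_DCP:
    # Guard the import as well so the module still loads on older PyTorch.
    from torch.distributed.checkpoint import state_dict as dist_state_dict


def full_cpu_state_dict(model):
    """Return a full, CPU-resident state dict of an FSDP-wrapped model."""
    if HAS_MODERN_DCP:
        # Newer (>=2.2) API: gather a full, CPU-offloaded state dict.
        state_dict_options = dist_state_dict.StateDictOptions(
            full_state_dict=True,
            cpu_offload=True,
        )
        return dist_state_dict.get_model_state_dict(
            model, options=state_dict_options)
    # Older API: the FSDP full-state-dict context manager, roughly the
    # pre-2.2 path that `save_model_singular` in main.py already uses.
    save_policy = fsdp.FullStateDictConfig(
        offload_to_cpu=True, rank0_only=True)
    with fsdp.FullyShardedDataParallel.state_dict_type(
            model,
            fsdp.StateDictType.FULL_STATE_DICT,
            save_policy,
    ):
        return model.state_dict()

Even this small amount of guarding keeps two save paths alive side by
side, which is the increase in complexity the commit message wants to
avoid by sticking to commented-out alternatives instead.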