Commit ceb61d4f authored by Jan Ebert

Add code for later PyTorch versions

We could put some version parsing and appropriate if-guards into the
code, but I'm afraid of the increase in complexity.
parent b4c87a7f
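As a rough illustration of the version-parsing-plus-if-guard approach mentioned in the commit message (and deliberately left out of this change), a check on `torch.__version__` could gate the two code paths. This is only a sketch under that assumption; the helper name `_supports_modern_dcp` is hypothetical and does not exist in the repository:

import torch

def _supports_modern_dcp():
    # Hypothetical helper, not part of this commit: True if this
    # PyTorch offers the newer distributed-checkpoint APIs (>=2.2).
    # A robust implementation would parse the version with
    # packaging.version instead of this naive split.
    major, minor = (int(part) for part in torch.__version__.split('.')[:2])
    return (major, minor) >= (2, 2)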
@@ -5,7 +5,11 @@ import time
 import torch
 from torch.distributed import checkpoint as dcp
+# The commented modules require PyTorch ≥2.2 and enable more modern
+# APIs.
+# from torch.distributed import device_mesh
 from torch.distributed import fsdp
+# from torch.distributed.checkpoint import state_dict as dist_state_dict
 import torchvision
@@ -104,6 +108,16 @@ def save_model_singular(model, *args, **kwargs):
             save_policy,
     ):
         cpu_state = model.state_dict()
+    # For PyTorch versions ≥2.2:
+    # state_dict_options = dist_state_dict.StateDictOptions(
+    #     full_state_dict=True,
+    #     cpu_offload=True,
+    # )
+    # cpu_state = dist_state_dict.get_model_state_dict(
+    #     model,
+    #     options=state_dict_options,
+    # )
     # We do *not* want to write to the same location with multiple
     # processes at the same time.
     if is_root_process():
@@ -123,10 +137,25 @@ def save_model(model, save_dir):
             state_dict_config,
     ):
         cp_state_dict = {'model': model.state_dict()}
         dcp.save_state_dict(
             cp_state_dict,
             dcp.FileSystemWriter(save_dir),
         )
+    # For PyTorch versions ≥2.2:
+    # state_dict_options = dist_state_dict.StateDictOptions(
+    #     cpu_offload=False,
+    # )
+    # model_state_dict = dist_state_dict.get_model_state_dict(
+    #     model,
+    #     options=state_dict_options,
+    # )
+    # cp_state_dict = {'model': model_state_dict}
+    #
+    # dcp.save(
+    #     cp_state_dict,
+    #     storage_writer=dcp.FileSystemWriter(save_dir, overwrite=True),
+    # )


 def load_model_singular(model, *args, **kwargs):
@@ -149,11 +178,27 @@ def load_model(model, load_dir):
             state_dict_config,
     ):
         cp_state_dict = {'model': model.state_dict()}
         dcp.load_state_dict(
             cp_state_dict,
             dcp.FileSystemReader(load_dir),
         )
         model.load_state_dict(cp_state_dict['model'])
+    # For PyTorch versions ≥2.2:
+    # state_dict_options = dist_state_dict.StateDictOptions(
+    #     cpu_offload=False,
+    # )
+    # model_state_dict = dist_state_dict.get_model_state_dict(
+    #     model,
+    #     options=state_dict_options,
+    # )
+    # cp_state_dict = {'model': model_state_dict}
+    #
+    # dcp.load(
+    #     cp_state_dict,
+    #     storage_reader=dcp.FileSystemReader(load_dir),
+    # )
+    # dist_state_dict.set_model_state_dict(model, cp_state_dict['model'])


 def all_reduce_avg(tensor):
@@ -276,9 +321,16 @@ def main():
     train_dset, valid_dset, test_dset = prepare_datasets(args, device)

     model = build_model()
+    # For PyTorch versions ≥2.2, `mesh_1d` should be included to enable
+    # use of more modern implementations:
+    # mesh_1d = device_mesh.init_device_mesh(
+    #     "cuda",
+    #     (torch.distributed.get_world_size(),),
+    # )
     model = fsdp.FullyShardedDataParallel(
         model,
         device_id=local_rank,
+        # device_mesh=mesh_1d,
         auto_wrap_policy=functools.partial(
             fsdp.wrap.size_based_auto_wrap_policy,
             # Wrap every 1B parameters.
...
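For comparison, here is a hedged sketch of how such a guard might be applied at one call site: a variant of the sharded checkpoint save that commits to the newer APIs shown in the commented-out lines above. It is not part of this commit; the function name `save_model_new_api` is made up for illustration, and it reuses the hypothetical `_supports_modern_dcp` helper sketched earlier together with the imports from the diff:

from torch.distributed import checkpoint as dcp
from torch.distributed.checkpoint import state_dict as dist_state_dict

def save_model_new_api(model, save_dir):
    # Sharded checkpoint save using only the PyTorch >=2.2 path from
    # the commented-out code in the diff above.
    if not _supports_modern_dcp():
        raise RuntimeError('this sketch requires PyTorch >= 2.2')
    state_dict_options = dist_state_dict.StateDictOptions(cpu_offload=False)
    cp_state_dict = {
        'model': dist_state_dict.get_model_state_dict(
            model,
            options=state_dict_options,
        ),
    }
    dcp.save(
        cp_state_dict,
        storage_writer=dcp.FileSystemWriter(save_dir, overwrite=True),
    )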