Skip to content
Snippets Groups Projects

update to pytorch 2.5

1 file
+ 38
78
Compare changes
  • Side-by-side
  • Inline
+ 38
78
@@ -5,11 +5,9 @@ import time
import torch
from torch.distributed import checkpoint as dcp
# The commented modules require PyTorch ≥2.2 and enable more modern
# APIs.
# from torch.distributed import device_mesh
from torch.distributed import device_mesh
from torch.distributed import fsdp
# from torch.distributed.checkpoint import state_dict as dist_state_dict
from torch.distributed.checkpoint import state_dict as dist_state_dict
import torchvision
@@ -100,23 +98,14 @@ def save_model_singular(model, *args, **kwargs):
other given arguments to `torch.save` to save the model, but only on
the root process.
"""
save_policy = fsdp.FullStateDictConfig(
offload_to_cpu=True, rank0_only=True)
with fsdp.FullyShardedDataParallel.state_dict_type(
model,
fsdp.StateDictType.FULL_STATE_DICT,
save_policy,
):
cpu_state = model.state_dict()
# For PyTorch versions ≥2.2:
# state_dict_options = dist_state_dict.StateDictOptions(
# full_state_dict=True,
# cpu_offload=True,
# )
# cpu_state = dist_state_dict.get_model_state_dict(
# model,
# options=state_dict_options,
# )
state_dict_options = dist_state_dict.StateDictOptions(
full_state_dict=True,
cpu_offload=True,
)
cpu_state = dist_state_dict.get_model_state_dict(
model,
options=state_dict_options,
)
# We do *not* want to write to the same location with multiple
# processes at the same time.
@@ -130,32 +119,19 @@ def save_model(model, save_dir):
distributed checkpoint means that the checkpoint will be split into
individual files, one for each process.
"""
state_dict_config = fsdp.ShardedStateDictConfig(offload_to_cpu=False)
with fsdp.FullyShardedDataParallel.state_dict_type(
model,
fsdp.StateDictType.SHARDED_STATE_DICT,
state_dict_config,
):
cp_state_dict = {'model': model.state_dict()}
dcp.save_state_dict(
state_dict_options = dist_state_dict.StateDictOptions(
cpu_offload=False,
)
model_state_dict = dist_state_dict.get_model_state_dict(
model,
options=state_dict_options,
)
cp_state_dict = {'model': model_state_dict}
dcp.save(
cp_state_dict,
dcp.FileSystemWriter(save_dir),
storage_writer=dcp.FileSystemWriter(save_dir, overwrite=True),
)
# For PyTorch versions ≥2.2:
# state_dict_options = dist_state_dict.StateDictOptions(
# cpu_offload=False,
# )
# model_state_dict = dist_state_dict.get_model_state_dict(
# model,
# options=state_dict_options,
# )
# cp_state_dict = {'model': model_state_dict}
#
# dcp.save(
# cp_state_dict,
# storage_writer=dcp.FileSystemWriter(save_dir, overwrite=True),
# )
def load_model_singular(model, *args, **kwargs):
@@ -171,34 +147,20 @@ def load_model(model, load_dir):
"""Set the given model's state dictionary in-place from the given
distributed checkpoint directory.
"""
state_dict_config = fsdp.ShardedStateDictConfig(offload_to_cpu=False)
with fsdp.FullyShardedDataParallel.state_dict_type(
model,
fsdp.StateDictType.SHARDED_STATE_DICT,
state_dict_config,
):
cp_state_dict = {'model': model.state_dict()}
dcp.load_state_dict(
state_dict_options = dist_state_dict.StateDictOptions(
cpu_offload=False,
)
model_state_dict = dist_state_dict.get_model_state_dict(
model,
options=state_dict_options,
)
cp_state_dict = {'model': model_state_dict}
dcp.load(
cp_state_dict,
dcp.FileSystemReader(load_dir),
storage_reader=dcp.FileSystemReader(load_dir),
)
model.load_state_dict(cp_state_dict['model'])
# For PyTorch versions ≥2.2:
# state_dict_options = dist_state_dict.StateDictOptions(
# cpu_offload=False,
# )
# model_state_dict = dist_state_dict.get_model_state_dict(
# model,
# options=state_dict_options,
# )
# cp_state_dict = {'model': model_state_dict}
#
# dcp.load(
# cp_state_dict,
# storage_reader=dcp.FileSystemReader(load_dir),
# )
# dist_state_dict.set_model_state_dict(model, cp_state_dict['model'])
dist_state_dict.set_model_state_dict(model, cp_state_dict['model'])
def all_reduce_avg(tensor):
@@ -321,16 +283,14 @@ def main():
train_dset, valid_dset, test_dset = prepare_datasets(args, device)
model = build_model()
# For PyTorch versions ≥2.2, `mesh_1d` should be included to enable
# use of more modern implementations:
# mesh_1d = device_mesh.init_device_mesh(
# "cuda",
# (torch.distributed.get_world_size(),),
# )
mesh_1d = device_mesh.init_device_mesh(
"cuda",
(torch.distributed.get_world_size(),),
)
model = fsdp.FullyShardedDataParallel(
model,
device_id=local_rank,
# device_mesh=mesh_1d,
device_mesh=mesh_1d,
auto_wrap_policy=functools.partial(
fsdp.wrap.size_based_auto_wrap_policy,
# Wrap every 1B parameters.
Loading