Skip to content
Snippets Groups Projects
Commit ceb61d4f authored by Jan Ebert's avatar Jan Ebert
Browse files

Add code for later PyTorch versions

We could put some version parsing and appropriate if-guards into the
code, but I'm afraid of the increase in complexity.
parent b4c87a7f
Branches
Tags release-v0.1.3
No related merge requests found
......@@ -5,7 +5,11 @@ import time
import torch
from torch.distributed import checkpoint as dcp
# The commented modules require PyTorch ≥2.2 and enable more modern
# APIs.
# from torch.distributed import device_mesh
from torch.distributed import fsdp
# from torch.distributed.checkpoint import state_dict as dist_state_dict
import torchvision
......@@ -104,6 +108,16 @@ def save_model_singular(model, *args, **kwargs):
save_policy,
):
cpu_state = model.state_dict()
# For PyTorch versions ≥2.2:
# state_dict_options = dist_state_dict.StateDictOptions(
# full_state_dict=True,
# cpu_offload=True,
# )
# cpu_state = dist_state_dict.get_model_state_dict(
# model,
# options=state_dict_options,
# )
# We do *not* want to write to the same location with multiple
# processes at the same time.
if is_root_process():
......@@ -123,10 +137,25 @@ def save_model(model, save_dir):
state_dict_config,
):
cp_state_dict = {'model': model.state_dict()}
dcp.save_state_dict(
cp_state_dict,
dcp.FileSystemWriter(save_dir),
)
# For PyTorch versions ≥2.2:
# state_dict_options = dist_state_dict.StateDictOptions(
# cpu_offload=False,
# )
# model_state_dict = dist_state_dict.get_model_state_dict(
# model,
# options=state_dict_options,
# )
# cp_state_dict = {'model': model_state_dict}
#
# dcp.save(
# cp_state_dict,
# storage_writer=dcp.FileSystemWriter(save_dir, overwrite=True),
# )
def load_model_singular(model, *args, **kwargs):
......@@ -149,11 +178,27 @@ def load_model(model, load_dir):
state_dict_config,
):
cp_state_dict = {'model': model.state_dict()}
dcp.load_state_dict(
cp_state_dict,
dcp.FileSystemReader(load_dir),
)
model.load_state_dict(cp_state_dict['model'])
# For PyTorch versions ≥2.2:
# state_dict_options = dist_state_dict.StateDictOptions(
# cpu_offload=False,
# )
# model_state_dict = dist_state_dict.get_model_state_dict(
# model,
# options=state_dict_options,
# )
# cp_state_dict = {'model': model_state_dict}
#
# dcp.load(
# cp_state_dict,
# storage_reader=dcp.FileSystemReader(load_dir),
# )
# dist_state_dict.set_model_state_dict(model, cp_state_dict['model'])
def all_reduce_avg(tensor):
......@@ -276,9 +321,16 @@ def main():
train_dset, valid_dset, test_dset = prepare_datasets(args, device)
model = build_model()
# For PyTorch versions ≥2.2, `mesh_1d` should be included to enable
# use of more modern implementations:
# mesh_1d = device_mesh.init_device_mesh(
# "cuda",
# (torch.distributed.get_world_size(),),
# )
model = fsdp.FullyShardedDataParallel(
model,
device_id=local_rank,
# device_mesh=mesh_1d,
auto_wrap_policy=functools.partial(
fsdp.wrap.size_based_auto_wrap_policy,
# Wrap every 1B parameters.
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment