Commit 6303cd82 authored by Jose Ignacio Robledo, committed by Jan Ebert

Fix deprecation warnings

Since the default PyTorch module now provides version 2.5, we can
finally integrate the previously commented-out changes, enabling the
newer APIs and fixing deprecation warnings.

Closes !1.
parent c40183f1
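For orientation, the core of the migration is replacing the deprecated FSDP `state_dict_type` context manager with the helpers in `torch.distributed.checkpoint.state_dict`. Condensed from the diff below (full-state-dict case only; not the exact file contents), the change amounts to roughly:

    # Deprecated pattern removed by this commit:
    save_policy = fsdp.FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with fsdp.FullyShardedDataParallel.state_dict_type(
            model, fsdp.StateDictType.FULL_STATE_DICT, save_policy):
        cpu_state = model.state_dict()

    # Replacement introduced by this commit:
    state_dict_options = dist_state_dict.StateDictOptions(
        full_state_dict=True, cpu_offload=True)
    cpu_state = dist_state_dict.get_model_state_dict(
        model, options=state_dict_options)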
@@ -5,11 +5,9 @@ import time
 import torch
 from torch.distributed import checkpoint as dcp
-# The commented modules require PyTorch ≥2.2 and enable more modern
-# APIs.
-# from torch.distributed import device_mesh
+from torch.distributed import device_mesh
 from torch.distributed import fsdp
-# from torch.distributed.checkpoint import state_dict as dist_state_dict
+from torch.distributed.checkpoint import state_dict as dist_state_dict
 import torchvision
@@ -100,23 +98,14 @@ def save_model_singular(model, *args, **kwargs):
     other given arguments to `torch.save` to save the model, but only on
     the root process.
     """
-    save_policy = fsdp.FullStateDictConfig(
-        offload_to_cpu=True, rank0_only=True)
-    with fsdp.FullyShardedDataParallel.state_dict_type(
-            model,
-            fsdp.StateDictType.FULL_STATE_DICT,
-            save_policy,
-    ):
-        cpu_state = model.state_dict()
-    # For PyTorch versions ≥2.2:
-    # state_dict_options = dist_state_dict.StateDictOptions(
-    #     full_state_dict=True,
-    #     cpu_offload=True,
-    # )
-    # cpu_state = dist_state_dict.get_model_state_dict(
-    #     model,
-    #     options=state_dict_options,
-    # )
+    state_dict_options = dist_state_dict.StateDictOptions(
+        full_state_dict=True,
+        cpu_offload=True,
+    )
+    cpu_state = dist_state_dict.get_model_state_dict(
+        model,
+        options=state_dict_options,
+    )
     # We do *not* want to write to the same location with multiple
     # processes at the same time.
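Note: the hunk above ends before the actual write. A minimal sketch of how the new API combines with the rank-0-only `torch.save` described in the docstring (the helper name, path argument, and final barrier are illustrative assumptions, not part of the commit):

    import torch
    import torch.distributed as dist
    from torch.distributed.checkpoint import state_dict as dist_state_dict

    def save_model_singular_sketch(model, path):
        # Gather a full (unsharded) state dict, offloaded to CPU memory.
        options = dist_state_dict.StateDictOptions(
            full_state_dict=True,
            cpu_offload=True,
        )
        cpu_state = dist_state_dict.get_model_state_dict(model, options=options)
        # Only the root process writes, so multiple processes never write
        # to the same location at the same time.
        if dist.get_rank() == 0:
            torch.save(cpu_state, path)
        # Assumed synchronization so other ranks do not race ahead.
        dist.barrier()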
@@ -130,32 +119,19 @@ def save_model(model, save_dir):
     distributed checkpoint means that the checkpoint will be split into
     individual files, one for each process.
     """
-    state_dict_config = fsdp.ShardedStateDictConfig(offload_to_cpu=False)
-    with fsdp.FullyShardedDataParallel.state_dict_type(
-            model,
-            fsdp.StateDictType.SHARDED_STATE_DICT,
-            state_dict_config,
-    ):
-        cp_state_dict = {'model': model.state_dict()}
-        dcp.save_state_dict(
-            cp_state_dict,
-            dcp.FileSystemWriter(save_dir),
-        )
-    # For PyTorch versions ≥2.2:
-    # state_dict_options = dist_state_dict.StateDictOptions(
-    #     cpu_offload=False,
-    # )
-    # model_state_dict = dist_state_dict.get_model_state_dict(
-    #     model,
-    #     options=state_dict_options,
-    # )
-    # cp_state_dict = {'model': model_state_dict}
-    #
-    # dcp.save(
-    #     cp_state_dict,
-    #     storage_writer=dcp.FileSystemWriter(save_dir, overwrite=True),
-    # )
+    state_dict_options = dist_state_dict.StateDictOptions(
+        cpu_offload=False,
+    )
+    model_state_dict = dist_state_dict.get_model_state_dict(
+        model,
+        options=state_dict_options,
+    )
+    cp_state_dict = {'model': model_state_dict}
+
+    dcp.save(
+        cp_state_dict,
+        storage_writer=dcp.FileSystemWriter(save_dir, overwrite=True),
+    )
 def load_model_singular(model, *args, **kwargs):
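Because `dcp.save` writes one shard per process (plus checkpoint metadata) into `save_dir`, the directory rather than a single file is the unit of a checkpoint. A hedged usage sketch, where the base path and epoch counter are made up for illustration:

    import os

    # Every rank passes the same directory; each writes only its own shard.
    ckpt_dir = os.path.join('/path/to/checkpoints', f'epoch_{epoch:04d}')
    save_model(model, ckpt_dir)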
@@ -171,34 +147,20 @@ def load_model(model, load_dir):
     """Set the given model's state dictionary in-place from the given
     distributed checkpoint directory.
     """
-    state_dict_config = fsdp.ShardedStateDictConfig(offload_to_cpu=False)
-    with fsdp.FullyShardedDataParallel.state_dict_type(
-            model,
-            fsdp.StateDictType.SHARDED_STATE_DICT,
-            state_dict_config,
-    ):
-        cp_state_dict = {'model': model.state_dict()}
-        dcp.load_state_dict(
-            cp_state_dict,
-            dcp.FileSystemReader(load_dir),
-        )
-        model.load_state_dict(cp_state_dict['model'])
-    # For PyTorch versions ≥2.2:
-    # state_dict_options = dist_state_dict.StateDictOptions(
-    #     cpu_offload=False,
-    # )
-    # model_state_dict = dist_state_dict.get_model_state_dict(
-    #     model,
-    #     options=state_dict_options,
-    # )
-    # cp_state_dict = {'model': model_state_dict}
-    #
-    # dcp.load(
-    #     cp_state_dict,
-    #     storage_reader=dcp.FileSystemReader(load_dir),
-    # )
-    # dist_state_dict.set_model_state_dict(model, cp_state_dict['model'])
+    state_dict_options = dist_state_dict.StateDictOptions(
+        cpu_offload=False,
+    )
+    model_state_dict = dist_state_dict.get_model_state_dict(
+        model,
+        options=state_dict_options,
+    )
+    cp_state_dict = {'model': model_state_dict}
+
+    dcp.load(
+        cp_state_dict,
+        storage_reader=dcp.FileSystemReader(load_dir),
+    )
+    dist_state_dict.set_model_state_dict(model, cp_state_dict['model'])
 def all_reduce_avg(tensor):
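Note that `dcp.load` restores tensors in place into the state dict passed to it, which is why `load_model` first fetches the model's current (sharded) state dict and only afterwards pushes the loaded values back with `set_model_state_dict`. A usage sketch of the two helpers after this commit (setup such as process-group initialization and FSDP wrapping is assumed to have happened as in `main()`):

    save_model(model, '/path/to/checkpoint')  # each rank writes its own shard
    load_model(model, '/path/to/checkpoint')  # each rank reads its shard back in place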
@@ -321,16 +283,14 @@ def main():
     train_dset, valid_dset, test_dset = prepare_datasets(args, device)
     model = build_model()
-    # For PyTorch versions ≥2.2, `mesh_1d` should be included to enable
-    # use of more modern implementations:
-    # mesh_1d = device_mesh.init_device_mesh(
-    #     "cuda",
-    #     (torch.distributed.get_world_size(),),
-    # )
+    mesh_1d = device_mesh.init_device_mesh(
+        "cuda",
+        (torch.distributed.get_world_size(),),
+    )
     model = fsdp.FullyShardedDataParallel(
         model,
         device_id=local_rank,
-        # device_mesh=mesh_1d,
+        device_mesh=mesh_1d,
         auto_wrap_policy=functools.partial(
             fsdp.wrap.size_based_auto_wrap_policy,
             # Wrap every 1B parameters.
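For reference, the tuple passed to `init_device_mesh` defines a single mesh dimension spanning all ranks, so the resulting 1-D mesh covers the whole job and is handed to FSDP via `device_mesh` instead of relying on the implicit default process group. A standalone sketch; the process-group setup shown here is an assumption (in the script it happens earlier in `main()`), and `model` stands for an already-built `torch.nn.Module`:

    import torch
    from torch.distributed import device_mesh, fsdp

    # Assumes torchrun-style environment variables for rendezvous.
    torch.distributed.init_process_group(backend='nccl')
    mesh_1d = device_mesh.init_device_mesh(
        'cuda',
        (torch.distributed.get_world_size(),),  # one dimension over all ranks
    )
    model = fsdp.FullyShardedDataParallel(model, device_mesh=mesh_1d)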