Skip to content
Snippets Groups Projects
Commit 68e35585 authored by Jan Ebert's avatar Jan Ebert
Browse files

Rename distributed checkpointing import

`dist_checkpoint` -> `dcp`

The DCP name is more canonical and it makes sense to make users aware of
this for improving their terminology.
parent 06888277
No related branches found
No related tags found
No related merge requests found
...@@ -4,7 +4,7 @@ import os ...@@ -4,7 +4,7 @@ import os
import time import time
import torch import torch
from torch.distributed import checkpoint as dist_checkpoint from torch.distributed import checkpoint as dcp
from torch.distributed import fsdp from torch.distributed import fsdp
import torchvision import torchvision
...@@ -123,9 +123,9 @@ def save_model(model, save_dir): ...@@ -123,9 +123,9 @@ def save_model(model, save_dir):
state_dict_config, state_dict_config,
): ):
cp_state_dict = {'model': model.state_dict()} cp_state_dict = {'model': model.state_dict()}
dist_checkpoint.save_state_dict( dcp.save_state_dict(
cp_state_dict, cp_state_dict,
dist_checkpoint.FileSystemWriter(save_dir), dcp.FileSystemWriter(save_dir),
) )
...@@ -140,9 +140,9 @@ def load_model(model, load_dir): ...@@ -140,9 +140,9 @@ def load_model(model, load_dir):
state_dict_config, state_dict_config,
): ):
cp_state_dict = {'model': model.state_dict()} cp_state_dict = {'model': model.state_dict()}
dist_checkpoint.load_state_dict( dcp.load_state_dict(
cp_state_dict, cp_state_dict,
dist_checkpoint.FileSystemReader(load_dir), dcp.FileSystemReader(load_dir),
) )
model.load_state_dict(cp_state_dict['model']) model.load_state_dict(cp_state_dict['model'])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment