diff --git a/pytorch-fsdp-example/main.py b/pytorch-fsdp-example/main.py index 5843373362c3d5486f84836d7d4470a54f4a80d0..991e6fc95d38be1c4ce8c841e20529a16ff56c66 100644 --- a/pytorch-fsdp-example/main.py +++ b/pytorch-fsdp-example/main.py @@ -4,7 +4,7 @@ import os import time import torch -from torch.distributed import checkpoint as dist_checkpoint +from torch.distributed import checkpoint as dcp from torch.distributed import fsdp import torchvision @@ -123,9 +123,9 @@ def save_model(model, save_dir): state_dict_config, ): cp_state_dict = {'model': model.state_dict()} - dist_checkpoint.save_state_dict( + dcp.save_state_dict( cp_state_dict, - dist_checkpoint.FileSystemWriter(save_dir), + dcp.FileSystemWriter(save_dir), ) @@ -140,9 +140,9 @@ def load_model(model, load_dir): state_dict_config, ): cp_state_dict = {'model': model.state_dict()} - dist_checkpoint.load_state_dict( + dcp.load_state_dict( cp_state_dict, - dist_checkpoint.FileSystemReader(load_dir), + dcp.FileSystemReader(load_dir), ) model.load_state_dict(cp_state_dict['model'])