diff --git a/pytorch-ddp-example/main.py b/pytorch-ddp-example/main.py index bea13495b98dde179a962f13ed2c70e4cf5e8e31..aaecdd7c7b6a7e01d3afea59bc372f7cd9be47fc 100644 --- a/pytorch-ddp-example/main.py +++ b/pytorch-ddp-example/main.py @@ -201,7 +201,7 @@ def test_model(model, loss_func, test_dset, device): def main(): args = parse_args() - torch.distributed.init_process_group(backend='nccl') + torch.distributed.init_process_group(backend='cpu:gloo,cuda:nccl') local_rank = get_local_rank() device = torch.device('cuda', local_rank) diff --git a/pytorch-fsdp-example/main.py b/pytorch-fsdp-example/main.py index b167d18eafb9348f00b1c4677684204f4b41a76b..5843373362c3d5486f84836d7d4470a54f4a80d0 100644 --- a/pytorch-fsdp-example/main.py +++ b/pytorch-fsdp-example/main.py @@ -249,7 +249,7 @@ def test_model(model, loss_func, test_dset, device): def main(): args = parse_args() - torch.distributed.init_process_group(backend='nccl') + torch.distributed.init_process_group(backend='cpu:gloo,cuda:nccl') local_rank = get_local_rank() device = torch.device('cuda', local_rank)