From bd050edf07f7ad509ec585bd7ca25ff14e2a0b5c Mon Sep 17 00:00:00 2001
From: janEbert <janpublicebert@posteo.net>
Date: Wed, 16 Oct 2024 20:27:13 +0200
Subject: [PATCH] Specify Gloo backend for CPU

This should make it easier for users to customize the process group,
e.g., to use asynchronous checkpointing. Asynchronous checkpointing
stages tensors on the CPU, so the process group needs a CPU-capable
backend such as Gloo alongside NCCL.
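
As a sketch of what this enables (assuming PyTorch >= 2.3, where
torch.distributed.checkpoint.async_save is available; the checkpoint
path and state-dict layout below are illustrative, not part of this
patch):

    import torch.distributed as dist
    import torch.distributed.checkpoint as dcp

    # Route CPU tensors through Gloo and CUDA tensors through NCCL.
    dist.init_process_group(backend='cpu:gloo,cuda:nccl')

    # ... set up `model` (e.g., a DDP- or FSDP-wrapped module) ...

    # `async_save` stages the state dict on the CPU and writes it in
    # the background; the staging collectives need a CPU-capable
    # (Gloo) backend, which an NCCL-only process group cannot provide.
    future = dcp.async_save(
        {'model': model.state_dict()},
        checkpoint_id='/path/to/checkpoint',  # illustrative path
    )
    # ... training continues while the checkpoint is written ...
    future.result()  # wait for the write to finish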
---
 pytorch-ddp-example/main.py  | 2 +-
 pytorch-fsdp-example/main.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/pytorch-ddp-example/main.py b/pytorch-ddp-example/main.py
index bea1349..aaecdd7 100644
--- a/pytorch-ddp-example/main.py
+++ b/pytorch-ddp-example/main.py
@@ -201,7 +201,7 @@ def test_model(model, loss_func, test_dset, device):
 def main():
     args = parse_args()
 
-    torch.distributed.init_process_group(backend='nccl')
+    torch.distributed.init_process_group(backend='cpu:gloo,cuda:nccl')
 
     local_rank = get_local_rank()
     device = torch.device('cuda', local_rank)
diff --git a/pytorch-fsdp-example/main.py b/pytorch-fsdp-example/main.py
index b167d18..5843373 100644
--- a/pytorch-fsdp-example/main.py
+++ b/pytorch-fsdp-example/main.py
@@ -249,7 +249,7 @@ def test_model(model, loss_func, test_dset, device):
 def main():
     args = parse_args()
 
-    torch.distributed.init_process_group(backend='nccl')
+    torch.distributed.init_process_group(backend='cpu:gloo,cuda:nccl')
 
     local_rank = get_local_rank()
     device = torch.device('cuda', local_rank)
-- 
GitLab