diff --git a/pytorch-ddp-example/main.py b/pytorch-ddp-example/main.py
index bea13495b98dde179a962f13ed2c70e4cf5e8e31..aaecdd7c7b6a7e01d3afea59bc372f7cd9be47fc 100644
--- a/pytorch-ddp-example/main.py
+++ b/pytorch-ddp-example/main.py
@@ -201,7 +201,7 @@ def test_model(model, loss_func, test_dset, device):
 def main():
     args = parse_args()
 
-    torch.distributed.init_process_group(backend='nccl')
+    torch.distributed.init_process_group(backend='cpu:gloo,cuda:nccl')
 
     local_rank = get_local_rank()
     device = torch.device('cuda', local_rank)
diff --git a/pytorch-fsdp-example/main.py b/pytorch-fsdp-example/main.py
index b167d18eafb9348f00b1c4677684204f4b41a76b..5843373362c3d5486f84836d7d4470a54f4a80d0 100644
--- a/pytorch-fsdp-example/main.py
+++ b/pytorch-fsdp-example/main.py
@@ -249,7 +249,7 @@ def test_model(model, loss_func, test_dset, device):
 def main():
     args = parse_args()
 
-    torch.distributed.init_process_group(backend='nccl')
+    torch.distributed.init_process_group(backend='cpu:gloo,cuda:nccl')
 
     local_rank = get_local_rank()
     device = torch.device('cuda', local_rank)