diff --git a/pytorch-ddp-example/main.py b/pytorch-ddp-example/main.py index c7b1b1adb75ef92fbfa90238a1e5f19285321949..bea13495b98dde179a962f13ed2c70e4cf5e8e31 100644 --- a/pytorch-ddp-example/main.py +++ b/pytorch-ddp-example/main.py @@ -289,6 +289,8 @@ def main(): print0('Final test loss:', test_loss) save0(model, 'model-final.pt') + torch.distributed.destroy_process_group() + if __name__ == '__main__': main() diff --git a/pytorch-fsdp-example/main.py b/pytorch-fsdp-example/main.py index 2e15d5f9e84e0cd9eb3ddaa9130745d148243329..b167d18eafb9348f00b1c4677684204f4b41a76b 100644 --- a/pytorch-fsdp-example/main.py +++ b/pytorch-fsdp-example/main.py @@ -341,6 +341,8 @@ def main(): print0('Final test loss:', test_loss) save_model(model, 'model-final') + torch.distributed.destroy_process_group() + if __name__ == '__main__': main()