From e72c89766a26f0263b1e852c6d92df5b7969bf99 Mon Sep 17 00:00:00 2001
From: janEbert <janpublicebert@posteo.net>
Date: Tue, 9 Jul 2024 17:10:57 +0200
Subject: [PATCH] Use square root LR scaling

Also adapt documentation regarding the default change.
---
 README.md                    | 13 ++++++-------
 pytorch-ddp-example/main.py  |  2 +-
 pytorch-fsdp-example/main.py |  2 +-
 3 files changed, 8 insertions(+), 9 deletions(-)

diff --git a/README.md b/README.md
index 43869f0..b7c7335 100644
--- a/README.md
+++ b/README.md
@@ -446,13 +446,12 @@ number of processes we use. That is because we only configure the
 batch size" is obtained by multiplying the local batch size times the
 number of processes. If we scale up the number of processes, we obtain
 a larger batch size; this, in turn, will change what learning
-rate we should use. A very simple heuristic is to just scale the base
-learning rate you would use for the local batch size proportional to
-the number of processes: for example, we just multiply the base
-learning rate times the number of processes. This is automatically
-done in the code so that it "just works" with a large number of
-processes, but ideally you would tune the learning rate manually for
-the global batch size you use.
+rate we should use. A simple heuristic is to multiply the base
+learning rate you would use for the local batch size by the square
+root of the number of processes. This can be done by supplying the
+`--scale-lr` argument so that it "just works" with an increasing
+number of processes, but ideally you would tune the learning rate
+manually for the global batch size you use.
 
 ## FSDP
 
diff --git a/pytorch-ddp-example/main.py b/pytorch-ddp-example/main.py
index 14d0774..29f0b4d 100644
--- a/pytorch-ddp-example/main.py
+++ b/pytorch-ddp-example/main.py
@@ -223,7 +223,7 @@ def main():
     lr = args.lr
     if args.scale_lr:
         # Scale learning rate according to number of processes.
-        lr *= torch.distributed.get_world_size()
+        lr *= torch.distributed.get_world_size()**0.5
     opt = torch.optim.AdamW(model.parameters(), lr=lr)
 
     # Maximum value of default dtype.
diff --git a/pytorch-fsdp-example/main.py b/pytorch-fsdp-example/main.py
index f8ebfbf..b49128e 100644
--- a/pytorch-fsdp-example/main.py
+++ b/pytorch-fsdp-example/main.py
@@ -275,7 +275,7 @@ def main():
     lr = args.lr
     if args.scale_lr:
         # Scale learning rate according to number of processes.
-        lr *= torch.distributed.get_world_size()
+        lr *= torch.distributed.get_world_size()**0.5
     opt = torch.optim.AdamW(model.parameters(), lr=lr)
 
     # Maximum value of default dtype.
-- 
GitLab
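
Note: the square-root rule applied in this patch can be illustrated outside the training scripts. The sketch below is a minimal, standalone example assuming the heuristic from the README change; the function name `scaled_lr` and the sample values are hypothetical and not part of the repository, which instead multiplies `args.lr` in place when `--scale-lr` is given and a `torch.distributed` process group is initialized.

```python
def scaled_lr(base_lr: float, world_size: int) -> float:
    """Scale a per-process learning rate by the square root of the
    number of processes, since the global batch size grows linearly
    with the number of processes."""
    return base_lr * world_size**0.5


if __name__ == "__main__":
    # Example: a base learning rate tuned for a single process,
    # rescaled for a few hypothetical world sizes.
    base_lr = 1e-3
    for world_size in (1, 4, 16, 64):
        print(f"world_size={world_size:3d}  lr={scaled_lr(base_lr, world_size):.2e}")
```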