From c941cc9c2379fbd43b5b223620b0bc2cd6154709 Mon Sep 17 00:00:00 2001
From: janEbert <janpublicebert@posteo.net>
Date: Wed, 9 Oct 2024 10:13:19 +0200
Subject: [PATCH] Also scale other optimizer parameters

This is just a minor increase in complexity, but it should give a lot
of intuition about what scaling up entails.

Also finally properly reference the work in the code/README.
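
As a rough illustration (a sketch only, not part of the patch; it
assumes a base learning rate of 1e-3 and the AdamW default betas and
eps), the scaling rules from https://arxiv.org/abs/2205.10287 change
the parameters for 4 processes like this:

    # Sketch; `world_size` stands in for
    # torch.distributed.get_world_size().
    world_size = 4
    lr = 1e-3 * world_size**0.5             # 1e-3  -> 2e-3
    betas = (1 - world_size * (1 - 0.9),    # 0.9   -> 0.6
             1 - world_size * (1 - 0.999))  # 0.999 -> 0.996
    eps = 1e-8 / world_size**0.5            # 1e-8  -> 5e-9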
---
 README.md                    | 16 +++++++++-------
 pytorch-ddp-example/main.py  | 28 +++++++++++++++++++++-------
 pytorch-fsdp-example/main.py | 28 +++++++++++++++++++++-------
 3 files changed, 51 insertions(+), 21 deletions(-)

diff --git a/README.md b/README.md
index e479f86..3d3edcb 100644
--- a/README.md
+++ b/README.md
@@ -528,13 +528,15 @@ number of processes we use. That is because we only configure the
 "local batch size", i.e., the batch size per process. The "global
 batch size" is obtained by multiplying the local batch size times the
 number of processes. If we scale up the number of processes, we obtain
-a larger batch size; this, in turn, this will change what learning
-rate we should use. A simple heuristic is to multiply the base
-learning rate you would use for the local batch size by the square
-root of the number of processes. This can be done by supplying the
-`--scale-lr` argument so that it "just works" with an increasing
-number of processes, but ideally you would tune the learning rate
-manually for the global batch size you use.
+a larger batch size; this, in turn, will change what optimizer
+parameters (such as the learning rate) we should use. A
+[simple-to-implement heuristic](https://arxiv.org/abs/2205.10287) is,
+for example, to multiply the base learning rate you would use for the
+local batch size by the square root of the number of processes. This
+heuristic and others can be applied by supplying the
+`--scale-optim-params` argument so that training "just works" with an
+increasing number of processes, but ideally you would tune the
+parameters manually for the global batch size you use.
 
 ## FSDP
 
diff --git a/pytorch-ddp-example/main.py b/pytorch-ddp-example/main.py
index 29f0b4d..c7b1b1a 100644
--- a/pytorch-ddp-example/main.py
+++ b/pytorch-ddp-example/main.py
@@ -17,15 +17,15 @@ def parse_args():
         help=(
             'Step size or learning rate of the optimizer. '
             'May be scaled according to the number of processes. '
-            '(See `--scale-lr`.)'
+            '(See `--scale-optim-params`.)'
         ),
     )
     parser.add_argument(
-        '--scale-lr',
+        '--scale-optim-params',
         action='store_true',
         help=(
-            'Whether the learning rate Will be scaled according to the '
-            'number of processes. (See `--batch-size`.)'
+            'Whether select optimizer parameters will be scaled according to '
+            'the number of processes. (See `--batch-size`.)'
         ),
     )
     parser.add_argument(
@@ -221,10 +221,24 @@ def main():
     loss_func = torch.nn.CrossEntropyLoss()
 
     lr = args.lr
-    if args.scale_lr:
-        # Scale learning rate according to number of processes.
+    # These are just the AdamW defaults.
+    adam_betas = (0.9, 0.999)
+    adam_eps = 1e-8
+    if args.scale_optim_params:
+        # See https://arxiv.org/abs/2205.10287.
+        # Scale optimizer parameters according to number of processes.
         lr *= torch.distributed.get_world_size()**0.5
-    opt = torch.optim.AdamW(model.parameters(), lr=lr)
+        adam_betas = (
+            1 - torch.distributed.get_world_size() * (1 - adam_betas[0]),
+            1 - torch.distributed.get_world_size() * (1 - adam_betas[1]),
+        )
+        adam_eps /= torch.distributed.get_world_size()**0.5
+    opt = torch.optim.AdamW(
+        model.parameters(),
+        lr=lr,
+        betas=adam_betas,
+        eps=adam_eps,
+    )
 
     # Maximum value of default dtype.
     min_valid_loss = torch.finfo(torch.get_default_dtype()).max
diff --git a/pytorch-fsdp-example/main.py b/pytorch-fsdp-example/main.py
index b49128e..2e15d5f 100644
--- a/pytorch-fsdp-example/main.py
+++ b/pytorch-fsdp-example/main.py
@@ -19,15 +19,15 @@ def parse_args():
         help=(
             'Step size or learning rate of the optimizer. '
             'May be scaled according to the number of processes. '
-            '(See `--scale-lr`.)'
+            '(See `--scale-optim-params`.)'
         ),
     )
     parser.add_argument(
-        '--scale-lr',
+        '--scale-optim-params',
         action='store_true',
         help=(
-            'Whether the learning rate Will be scaled according to the '
-            'number of processes. (See `--batch-size`.)'
+            'Whether select optimizer parameters will be scaled according to '
+            'the number of processes. (See `--batch-size`.)'
         ),
     )
     parser.add_argument(
@@ -273,10 +273,24 @@ def main():
     loss_func = torch.nn.CrossEntropyLoss()
 
     lr = args.lr
-    if args.scale_lr:
-        # Scale learning rate according to number of processes.
+    # These are just the AdamW defaults.
+    adam_betas = (0.9, 0.999)
+    adam_eps = 1e-8
+    if args.scale_optim_params:
+        # See https://arxiv.org/abs/2205.10287.
+        # Scale optimizer parameters according to number of processes.
         lr *= torch.distributed.get_world_size()**0.5
-    opt = torch.optim.AdamW(model.parameters(), lr=lr)
+        adam_betas = (
+            1 - torch.distributed.get_world_size() * (1 - adam_betas[0]),
+            1 - torch.distributed.get_world_size() * (1 - adam_betas[1]),
+        )
+        adam_eps /= torch.distributed.get_world_size()**0.5
+    opt = torch.optim.AdamW(
+        model.parameters(),
+        lr=lr,
+        betas=adam_betas,
+        eps=adam_eps,
+    )
 
     # Maximum value of default dtype.
     min_valid_loss = torch.finfo(torch.get_default_dtype()).max
-- 
GitLab
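
For reference, the scaling applied in both examples can be pulled out
into a small standalone helper (a sketch only; `scale_adamw_params` is
a hypothetical name, not part of the repository). Note that
`torch.optim.AdamW` requires betas in [0, 1), so with the default
beta1 of 0.9 the linear beta rule only remains valid up to about 10
processes; beyond that, tune the parameters manually as the README
suggests.

    # Sketch of the scaling rules from https://arxiv.org/abs/2205.10287
    # as used in the patch above. `scale_adamw_params` is hypothetical,
    # not part of the repository.
    import torch


    def scale_adamw_params(world_size, lr=1e-3, betas=(0.9, 0.999),
                           eps=1e-8):
        """Scale AdamW parameters for a `world_size` times larger batch."""
        lr = lr * world_size**0.5
        betas = tuple(1 - world_size * (1 - beta) for beta in betas)
        eps = eps / world_size**0.5
        # torch.optim.AdamW requires 0 <= beta < 1, so this rule breaks
        # down once world_size * (1 - beta) reaches 1.
        assert all(0 <= beta < 1 for beta in betas), 'betas out of range'
        return lr, betas, eps


    if __name__ == '__main__':
        model = torch.nn.Linear(8, 2)
        lr, betas, eps = scale_adamw_params(world_size=4)
        opt = torch.optim.AdamW(model.parameters(), lr=lr, betas=betas,
                                eps=eps)
        print(lr, betas, eps)  # approx. 0.002 (0.6, 0.996) 5e-09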