diff --git a/README.md b/README.md
index e479f865e12624cc9f5ff2f619a3f4edf7932f15..3d3edcb27315960be44181d44d778a4f911b8676 100644
--- a/README.md
+++ b/README.md
@@ -528,13 +528,15 @@ number of processes we use. That is because we only configure the
 "local batch size", i.e., the batch size per process. The "global
 batch size" is obtained by multiplying the local batch size times the
 number of processes. If we scale up the number of processes, we obtain
-a larger batch size; this, in turn, this will change what learning
-rate we should use. A simple heuristic is to multiply the base
-learning rate you would use for the local batch size by the square
-root of the number of processes. This can be done by supplying the
-`--scale-lr` argument so that it "just works" with an increasing
-number of processes, but ideally you would tune the learning rate
-manually for the global batch size you use.
+a larger batch size; this, in turn, will change what optimizer
+parameters (such as learning rate) we should use. One
+[simple-to-implement heuristic](https://arxiv.org/abs/2205.10287) is
+to multiply the base learning rate you would use for the local batch
+size by the square root of the number of processes. This heuristic
+and others can be applied by supplying the `--scale-optim-params`
+argument so that training "just works" with an increasing number of
+processes, but ideally you would tune the parameters manually for the
+global batch size you use.
 
 ## FSDP
 
diff --git a/pytorch-ddp-example/main.py b/pytorch-ddp-example/main.py
index 29f0b4dc530daf734ea1c9da68e4d8bff3c913b4..c7b1b1adb75ef92fbfa90238a1e5f19285321949 100644
--- a/pytorch-ddp-example/main.py
+++ b/pytorch-ddp-example/main.py
@@ -17,15 +17,15 @@ def parse_args():
         help=(
             'Step size or learning rate of the optimizer. '
             'May be scaled according to the number of processes. '
-            '(See `--scale-lr`.)'
+            '(See `--scale-optim-params`.)'
         ),
     )
     parser.add_argument(
-        '--scale-lr',
+        '--scale-optim-params',
         action='store_true',
         help=(
-            'Whether the learning rate Will be scaled according to the '
-            'number of processes. (See `--batch-size`.)'
+            'Whether select optimizer parameters will be scaled according to '
+            'the number of processes. (See `--batch-size`.)'
         ),
     )
     parser.add_argument(
@@ -221,10 +221,24 @@ def main():
     loss_func = torch.nn.CrossEntropyLoss()
 
     lr = args.lr
-    if args.scale_lr:
-        # Scale learning rate according to number of processes.
+    # These are just the AdamW defaults.
+    adam_betas = (0.9, 0.999)
+    adam_eps = 1e-8
+    if args.scale_optim_params:
+        # See https://arxiv.org/abs/2205.10287.
+        # Scale optimizer parameters according to number of processes.
         lr *= torch.distributed.get_world_size()**0.5
-    opt = torch.optim.AdamW(model.parameters(), lr=lr)
+        adam_betas = (
+            1 - torch.distributed.get_world_size() * (1 - adam_betas[0]),
+            1 - torch.distributed.get_world_size() * (1 - adam_betas[1]),
+        )
+        adam_eps /= torch.distributed.get_world_size()**0.5
+    opt = torch.optim.AdamW(
+        model.parameters(),
+        lr=lr,
+        betas=adam_betas,
+        eps=adam_eps,
+    )
 
     # Maximum value of default dtype.
     min_valid_loss = torch.finfo(torch.get_default_dtype()).max
diff --git a/pytorch-fsdp-example/main.py b/pytorch-fsdp-example/main.py
index b49128ea2db6417a12499fbb93d5441899670a16..2e15d5f9e84e0cd9eb3ddaa9130745d148243329 100644
--- a/pytorch-fsdp-example/main.py
+++ b/pytorch-fsdp-example/main.py
@@ -19,15 +19,15 @@ def parse_args():
         help=(
             'Step size or learning rate of the optimizer. '
             'May be scaled according to the number of processes. '
-            '(See `--scale-lr`.)'
+            '(See `--scale-optim-params`.)'
         ),
     )
     parser.add_argument(
-        '--scale-lr',
+        '--scale-optim-params',
         action='store_true',
         help=(
-            'Whether the learning rate Will be scaled according to the '
-            'number of processes. (See `--batch-size`.)'
+            'Whether select optimizer parameters will be scaled according to '
+            'the number of processes. (See `--batch-size`.)'
         ),
     )
     parser.add_argument(
@@ -273,10 +273,24 @@ def main():
     loss_func = torch.nn.CrossEntropyLoss()
 
     lr = args.lr
-    if args.scale_lr:
-        # Scale learning rate according to number of processes.
+    # These are just the AdamW defaults.
+    adam_betas = (0.9, 0.999)
+    adam_eps = 1e-8
+    if args.scale_optim_params:
+        # See https://arxiv.org/abs/2205.10287.
+        # Scale optimizer parameters according to number of processes.
         lr *= torch.distributed.get_world_size()**0.5
-    opt = torch.optim.AdamW(model.parameters(), lr=lr)
+        adam_betas = (
+            1 - torch.distributed.get_world_size() * (1 - adam_betas[0]),
+            1 - torch.distributed.get_world_size() * (1 - adam_betas[1]),
+        )
+        adam_eps /= torch.distributed.get_world_size()**0.5
+    opt = torch.optim.AdamW(
+        model.parameters(),
+        lr=lr,
+        betas=adam_betas,
+        eps=adam_eps,
+    )
 
     # Maximum value of default dtype.
     min_valid_loss = torch.finfo(torch.get_default_dtype()).max
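For reference, the square-root scaling rule that both hunks apply inline can be written as a small standalone helper. This is only an illustrative sketch: the function name `scale_adamw_params` is hypothetical and not part of the examples, and `world_size` stands in for `torch.distributed.get_world_size()`.

```python
# Sketch of the square-root scaling rule from https://arxiv.org/abs/2205.10287,
# mirroring what the diff above does inline. `world_size` is the number of
# processes, i.e. the factor by which the global batch size grows.
def scale_adamw_params(lr, betas=(0.9, 0.999), eps=1e-8, world_size=1):
    """Return (lr, betas, eps) adjusted for a `world_size` times larger batch."""
    scaled_lr = lr * world_size**0.5
    scaled_betas = tuple(1 - world_size * (1 - beta) for beta in betas)
    scaled_eps = eps / world_size**0.5
    return scaled_lr, scaled_betas, scaled_eps


# Example: scale_adamw_params(1e-3, world_size=4)
# yields roughly (0.002, (0.6, 0.996), 5e-9).
```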