diff --git a/README.md b/README.md
index 43869f0b2b8b6ca6c3465993d5536e8238ed10a9..b7c7335d86c846a2c4e33653067f48e3e00d4ce9 100644
--- a/README.md
+++ b/README.md
@@ -446,13 +446,12 @@ number of processes we use. That is because we only configure the
 batch size" is obtained by multiplying the local batch size times the
 number of processes. If we scale up the number of processes, we
 obtain a larger batch size; this, in turn, this will change what learning
-rate we should use. A very simple heuristic is to just scale the base
-learning rate you would use for the local batch size proportional to
-the number of processes: for example, we just multiply the base
-learning rate times the number of processes. This is automatically
-done in the code so that it "just works" with a large number of
-processes, but ideally you would tune the learning rate manually for
-the global batch size you use.
+rate we should use. A simple heuristic is to multiply the base
+learning rate you would use for the local batch size by the square
+root of the number of processes. This can be done by supplying the
+`--scale-lr` argument so that it "just works" with an increasing
+number of processes, but ideally you would tune the learning rate
+manually for the global batch size you use.
 
 ## FSDP
 
diff --git a/pytorch-ddp-example/main.py b/pytorch-ddp-example/main.py
index 14d07749d3b7f5f73aadb84d795ec066febaa34d..29f0b4dc530daf734ea1c9da68e4d8bff3c913b4 100644
--- a/pytorch-ddp-example/main.py
+++ b/pytorch-ddp-example/main.py
@@ -223,7 +223,7 @@ def main():
     lr = args.lr
     if args.scale_lr:
         # Scale learning rate according to number of processes.
-        lr *= torch.distributed.get_world_size()
+        lr *= torch.distributed.get_world_size()**0.5
     opt = torch.optim.AdamW(model.parameters(), lr=lr)
 
     # Maximum value of default dtype.
diff --git a/pytorch-fsdp-example/main.py b/pytorch-fsdp-example/main.py
index f8ebfbfcb718154686749898fca193e13de32f6b..b49128ea2db6417a12499fbb93d5441899670a16 100644
--- a/pytorch-fsdp-example/main.py
+++ b/pytorch-fsdp-example/main.py
@@ -275,7 +275,7 @@ def main():
     lr = args.lr
     if args.scale_lr:
         # Scale learning rate according to number of processes.
-        lr *= torch.distributed.get_world_size()
+        lr *= torch.distributed.get_world_size()**0.5
     opt = torch.optim.AdamW(model.parameters(), lr=lr)
 
     # Maximum value of default dtype.
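
For reference, a minimal sketch (not part of the patch) of what the square-root scaling heuristic amounts to, using hypothetical values for the world size and base learning rate:

    # Hypothetical values for illustration; in the examples they come from
    # torch.distributed.get_world_size() and the --lr command-line argument.
    world_size = 8
    base_lr = 3e-4  # learning rate tuned for the local batch size

    # Square-root scaling: the global batch size is world_size times the
    # local batch size, so the base learning rate is multiplied by
    # sqrt(world_size).
    scaled_lr = base_lr * world_size ** 0.5  # 3e-4 * sqrt(8) ~= 8.49e-4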