diff --git a/README.md b/README.md
index d47b479a078619814288a35ef4677658af8b78cc..b07ce85d912bcd7568eb76f4c1909773123cef82 100644
--- a/README.md
+++ b/README.md
@@ -610,3 +610,31 @@
 that collects the full model on the CPU and then saves it in a single
 checkpoint file which can then be loaded in a single process. Keep in
 mind that this way of checkpointing is slower and limited by CPU memory.
+
+### HSDP
+
+"Hybrid sharded data parallel" (HSDP) is a way to reduce communication
+in FSDP training when your model does not need to be sharded across
+all processes. It works by creating independent replicas of the fully
+sharded model and feeding each replica distinct data, just as DDP
+does. As in DDP, the gradients computed by the replicas on their
+different input batches are averaged across replicas. HSDP is thus a
+combination of FSDP and DDP. It is usually recommended to use one
+replica per node, so that the model is only sharded within a node, but
+your mileage may vary. In particular, if the model is too large to fit
+even when sharded across a single node, you will have to shard it
+across more than one node.
+
+Communication is reduced because the expensive collect-discard steps
+for each FSDP unit's sharded parameters no longer span all processes;
+they only run among the processes within each replica, while a cheaper
+gradient-averaging step runs across replicas instead. Additionally,
+communication inside a node is usually much faster than across nodes,
+so with one replica per node the expensive communication stays where
+bandwidth is high and the cheaper communication goes where bandwidth
+is low.
+
+To enable HSDP, pass the desired number of FSDP replicas using the
+`--num-fsdp-replicas` argument. If this argument is not given,
+standard FSDP is used.
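+As a rough illustration, the sketch below shows how a replica count
+like the one given via `--num-fsdp-replicas` can be mapped onto
+PyTorch's built-in hybrid sharding. It is a minimal sketch, not this
+codebase's actual implementation: it assumes a recent PyTorch (2.2+)
+with `init_device_mesh` and `ShardingStrategy.HYBRID_SHARD`, a
+`torchrun`-style launcher that sets `LOCAL_RANK`, and a placeholder
+model and argument parser.
+
+```python
+import argparse
+import os
+
+import torch
+import torch.distributed as dist
+from torch.distributed.device_mesh import init_device_mesh
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import ShardingStrategy
+
+# Hypothetical argument handling mirroring the --num-fsdp-replicas flag.
+parser = argparse.ArgumentParser()
+parser.add_argument("--num-fsdp-replicas", type=int, default=None)
+args = parser.parse_args()
+
+dist.init_process_group("nccl")
+torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+world_size = dist.get_world_size()
+
+model = torch.nn.Linear(1024, 1024).cuda()  # placeholder model
+
+if args.num_fsdp_replicas is None:
+    # Standard FSDP: shard parameters across all processes.
+    model = FSDP(model, sharding_strategy=ShardingStrategy.FULL_SHARD)
+else:
+    # HSDP: the outer mesh dimension replicates the model (DDP-style
+    # gradient averaging), the inner dimension shards it (FSDP-style).
+    shard_size = world_size // args.num_fsdp_replicas
+    mesh = init_device_mesh(
+        "cuda",
+        (args.num_fsdp_replicas, shard_size),
+        mesh_dim_names=("replicate", "shard"),
+    )
+    model = FSDP(
+        model,
+        device_mesh=mesh,
+        sharding_strategy=ShardingStrategy.HYBRID_SHARD,
+    )
+```
+
+With one replica per node, the `shard` dimension lines up with the
+fast intra-node interconnect and the `replicate` dimension with the
+slower inter-node link, which matches the bandwidth argument above.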