From 7274481bc4d0a525c88f77def75b1ec70939ebb4 Mon Sep 17 00:00:00 2001
From: janEbert <janpublicebert@posteo.net>
Date: Tue, 9 Jul 2024 17:21:19 +0200
Subject: [PATCH] Explain FSDP

---
 README.md | 61 ++++++++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 60 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index b7c7335..74a9ccf 100644
--- a/README.md
+++ b/README.md
@@ -35,6 +35,7 @@ https://medium.com/pytorch/pytorch-data-parallel-best-practices-on-google-cloud-
 - [DDP](#ddp)
   - [DDP considerations](#ddp-considerations)
 - [FSDP](#fsdp)
+  - [FSDP considerations](#fsdp-considerations)

 ## General

@@ -455,4 +456,62 @@ manually for the global batch size you use.

 ## FSDP

-Currently missing.
+This example will use
+[`FullyShardedDataParallel`](https://pytorch.org/docs/stable/fsdp.html) (FSDP)
+to do model- and data-parallel training. This means that we will store
+different parts of the model on different GPUs and also evaluate it on
+different batches on each GPU. Similarly, the gradients and optimizer
+states are also distributed.
+
+For initialization, FSDP first defines a hierarchy of distinct, but
+possibly nested, submodules ("units") for the model. This process is
+also called "wrapping" in FSDP terminology and can be controlled using
+the `auto_wrap_policy` argument to `FullyShardedDataParallel`. The
+parameters in each unit are then split and distributed ("sharded", or
+scattered) to all GPUs. In the end, each GPU contains its own,
+distinct model shard.
+
+Whenever we do a forward pass with the model, we sequentially pass
+through units in the following way: FSDP automatically collects the
+parameters in the unit that is currently being processed across all
+GPUs, so that each GPU contains the full unit that is being processed.
+Remember that in addition, the GPU also contains other parts of the
+model from other units, but in sharded form. Each GPU can then
+execute the unit's forward pass on its own input batch to obtain its
+own local activations. The GPU then discards the additional parameters
+that it received. This process continues with the next unit until the
+model's forward pass is completely processed and each GPU contains its
+own final outputs. As a simplification, you can imagine that each GPU
+first collects parameters for the first layer, does the first layer's
+forward, discards the parameters again, does the same process for the
+second layer, and so on until it obtains the final layer's outputs.
+Remember that each GPU does this process with distinct data, so the
+outputs will be different on each GPU.
+
+When taking gradients (`loss.backward()`), a similar scheme is
+executed in reverse: we collect the parameters for the last unit
+on each GPU, calculate the local gradients for that unit, and
+again discard the additional parameters. We then also want to average
+those local gradients across all processes to obtain shared global
+per-unit gradients, so that – just like in DDP – each update step will
+be the same across all processes. After averaging, the gradients are
+sharded so that each GPU only receives the gradients that it requires
+to update its own shard of the unit that is currently being processed.
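+
+If you have never used FSDP, here is a minimal sketch of how wrapping
+and a single training step might look. The toy model, the size-based
+`auto_wrap_policy`, and all hyperparameters are illustrative
+stand-ins and not the example's actual code; it assumes one process
+per GPU launched via `torchrun`:
+
+```python
+import functools
+import os
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
+
+
+def main():
+    # One process per GPU; `torchrun` sets these environment variables.
+    dist.init_process_group(backend='nccl')
+    local_rank = int(os.environ['LOCAL_RANK'])
+    torch.cuda.set_device(local_rank)
+
+    # A stand-in model; the real example uses its own architecture.
+    model = torch.nn.Sequential(
+        torch.nn.Linear(1024, 1024),
+        torch.nn.ReLU(),
+        torch.nn.Linear(1024, 10),
+    )
+
+    # "Wrap" the model: the `auto_wrap_policy` decides which submodules
+    # become their own FSDP units; here, a submodule becomes a unit
+    # once it exceeds a parameter count threshold. Each unit's
+    # parameters are then sharded across all GPUs.
+    wrap_policy = functools.partial(
+        size_based_auto_wrap_policy,
+        min_num_params=int(1e5),
+    )
+    model = FSDP(model, auto_wrap_policy=wrap_policy,
+                 device_id=local_rank)
+
+    # Build the optimizer *after* wrapping so that it only references
+    # this GPU's shard of the parameters.
+    opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
+
+    # A training step looks just like the DDP version; FSDP gathers
+    # and discards full units behind the scenes during forward and
+    # backward, then reduce-scatters the gradients.
+    inputs = torch.randn(8, 1024, device='cuda')
+    labels = torch.randint(0, 10, (8,), device='cuda')
+    loss = F.cross_entropy(model(inputs), labels)
+    loss.backward()
+    opt.step()
+    opt.zero_grad()
+
+    dist.destroy_process_group()
+
+
+if __name__ == '__main__':
+    main()
+```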
+
+### FSDP considerations
+
+All the [DDP considerations](#ddp-considerations) apply.
+
+Additionally, working with checkpoints becomes even more bothersome.
+Since each GPU contains its own model shard and the full model would
+potentially not fit even on the CPU, we simply save each GPU's own
+shards in a separate checkpoint file. The same goes for loading the
+checkpoint: each GPU loads its own checkpoint file for its model
+shard. This also means that we have to execute saving and loading on
+every process, since the data is fully distinct.
+
+The example also contains an unused `save_model_singular` function
+that gathers the full model on the CPU and then saves it in a single
+checkpoint file, which can then be loaded in a single process. Keep in
+mind that this way of checkpointing is slower and limited by CPU
+memory.
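+
+As a rough sketch of both ways of checkpointing (the helper names,
+the file paths, and the choice of `StateDictType.LOCAL_STATE_DICT`
+below are assumptions for illustration, not necessarily what the
+example implements):
+
+```python
+import torch
+import torch.distributed as dist
+from torch.distributed.fsdp import (
+    FullStateDictConfig,
+    FullyShardedDataParallel as FSDP,
+    StateDictType,
+)
+
+
+def save_sharded(model, path_template='checkpoint_rank{}.pt'):
+    """Save only this GPU's shards; every rank writes its own file."""
+    with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
+        shards = model.state_dict()
+    torch.save(shards, path_template.format(dist.get_rank()))
+
+
+def load_sharded(model, path_template='checkpoint_rank{}.pt'):
+    """Load this GPU's shards; assumes the same world size and
+    wrapping as when the checkpoint was saved."""
+    shards = torch.load(path_template.format(dist.get_rank()))
+    with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
+        model.load_state_dict(shards)
+
+
+def save_full(model, path='model_full.pt'):
+    """Gather the full model on the CPU and write a single file from
+    rank 0, similar in spirit to `save_model_singular`."""
+    config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
+    with FSDP.state_dict_type(
+            model, StateDictType.FULL_STATE_DICT, config):
+        full_state = model.state_dict()
+    if dist.get_rank() == 0:
+        torch.save(full_state, path)
+```
--
GitLab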