From 7274481bc4d0a525c88f77def75b1ec70939ebb4 Mon Sep 17 00:00:00 2001
From: janEbert <janpublicebert@posteo.net>
Date: Tue, 9 Jul 2024 17:21:19 +0200
Subject: [PATCH] Explain FSDP

---
 README.md | 61 ++++++++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 60 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index b7c7335..74a9ccf 100644
--- a/README.md
+++ b/README.md
@@ -35,6 +35,7 @@ https://medium.com/pytorch/pytorch-data-parallel-best-practices-on-google-cloud-
     - [DDP](#ddp)
         - [DDP considerations](#ddp-considerations)
     - [FSDP](#fsdp)
+        - [FSDP considerations](#fsdp-considerations)
 
 ## General
 
@@ -455,4 +456,62 @@ manually for the global batch size you use.
 
 ## FSDP
 
-Currently missing.
+This example uses
+[`FullyShardedDataParallel`](https://pytorch.org/docs/stable/fsdp.html) (FSDP)
+to combine model- and data-parallel training: each GPU stores only a
+part of the model, and each GPU also processes a different batch of
+data. The gradients and optimizer states are likewise distributed
+across GPUs.
+
+For initialization, FSDP first partitions the model into a hierarchy
+of distinct, but possibly nested, submodules ("units"). This process
+is called "wrapping" in FSDP terminology and can be controlled using
+the `auto_wrap_policy` argument to `FullyShardedDataParallel`. The
+parameters in each unit are then split and distributed ("sharded", or
+scattered) across all GPUs, so that in the end, each GPU holds its
+own, distinct model shard.
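+
+As a toy illustration in plain Python (not the actual FSDP
+implementation, which operates on flattened parameter tensors),
+sharding a unit's parameters across GPUs ("ranks") could look like
+this:
+
+```python
+# Toy sketch of FSDP-style parameter sharding (plain Python, no PyTorch).
+# A unit's parameters are flattened, padded so their number divides evenly,
+# and split into one contiguous shard per rank.
+
+def shard_parameters(flat_params, world_size):
+    """Split a flat parameter list into world_size equal shards, padding with 0.0."""
+    shard_len = -(-len(flat_params) // world_size)  # ceiling division
+    padded = flat_params + [0.0] * (shard_len * world_size - len(flat_params))
+    return [padded[r * shard_len:(r + 1) * shard_len] for r in range(world_size)]
+
+# 5 parameters sharded across 2 ranks; rank 1's shard gets padded.
+shards = shard_parameters([0.1, 0.2, 0.3, 0.4, 0.5], world_size=2)
+print(shards)  # [[0.1, 0.2, 0.3], [0.4, 0.5, 0.0]]
+```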
+
+Whenever we do a forward pass with the model, we pass through the
+units sequentially: FSDP automatically gathers the parameters of the
+unit that is currently being processed from all GPUs, so that each
+GPU temporarily holds that unit in full. (Each GPU still holds the
+other units only in their sharded form.) Each GPU then executes the
+unit's forward pass on its own input batch to compute its own local
+activations, after which it discards the gathered parameters again.
+This process continues with the next unit until the model's forward
+pass is complete and each GPU holds its own final outputs. As a
+simplification, you can imagine that each GPU first gathers the
+parameters of the first layer, runs the first layer's forward pass,
+discards the parameters again, repeats the same process for the
+second layer, and so on until it obtains the final layer's outputs.
+Remember that each GPU runs this process on distinct data, so the
+outputs will differ between GPUs.
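+
+The per-unit gather–compute–discard loop can be sketched in plain
+Python (a toy simulation, not real collective communication; the
+"forward" here is just a scalar multiplication):
+
+```python
+# Toy simulation of FSDP's per-unit forward pass (plain Python, no PyTorch).
+# units[u][rank] holds rank's shard of unit u; each rank has its own batch.
+
+def all_gather(shards):
+    """Concatenate every rank's shard into the unit's full parameter list."""
+    return [p for shard in shards for p in shard]
+
+def unit_forward(full_params, x):
+    """Toy 'layer': scale the activation by the sum of the unit's parameters."""
+    return x * sum(full_params)
+
+units = [[[1.0], [2.0]], [[0.5], [0.5]]]  # two units, sharded over two ranks
+batches = [2.0, 4.0]                      # one scalar "batch" per rank
+
+outputs = []
+for rank, x in enumerate(batches):
+    act = x
+    for shards in units:
+        full = all_gather(shards)  # materialize the full unit on this rank
+        act = unit_forward(full, act)
+        del full                   # discard the gathered parameters again
+    outputs.append(act)
+print(outputs)  # [6.0, 12.0] -- a distinct output per rank
+```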
+
+When taking gradients (`loss.backward()`), a similar scheme runs in
+reverse: we gather the parameters of the very last unit on each GPU,
+compute the local gradients for that unit, and again discard the
+gathered parameters. We then average these local gradients across all
+processes to obtain shared, global per-unit gradients, so that – just
+like in DDP – the update step is the same on every process. After
+averaging, the gradients are sharded again so that each GPU receives
+only the gradients it needs to update its own shard of the unit
+currently being processed. (In practice, the averaging and sharding
+happen together in a single reduce-scatter operation.)
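+
+The average-then-shard step can likewise be sketched in plain Python
+(a toy stand-in for the reduce-scatter collective, not PyTorch's
+implementation):
+
+```python
+# Toy sketch of FSDP's gradient reduce-scatter (plain Python, no PyTorch).
+# Per-rank local gradients for one unit are averaged element-wise, and each
+# rank keeps only the slice matching its own parameter shard.
+
+def reduce_scatter_mean(local_grads, world_size):
+    """Average per-rank gradient lists, then hand each rank its own slice."""
+    n = len(local_grads[0])
+    mean = [sum(g[i] for g in local_grads) / world_size for i in range(n)]
+    shard_len = n // world_size
+    return [mean[r * shard_len:(r + 1) * shard_len] for r in range(world_size)]
+
+local = [[1.0, 2.0, 3.0, 4.0],  # rank 0's local gradients (its own batch)
+         [3.0, 2.0, 1.0, 0.0]]  # rank 1's local gradients
+print(reduce_scatter_mean(local, world_size=2))  # [[2.0, 2.0], [2.0, 2.0]]
+```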
+
+### FSDP considerations
+
+All the [DDP considerations](#ddp-considerations) apply.
+
+Additionally, working with checkpoints becomes even more bothersome.
+Since each GPU holds its own model shard and the full model might not
+fit even in CPU memory, we simply save each GPU's own shards to a
+separate checkpoint file. The same goes for loading: each GPU loads
+its own checkpoint file for its model shard. This also means that we
+have to execute saving and loading on every process, since the data
+is fully distinct.
+
+The example also contains an unused `save_model_singular` function
+that gathers the full model on the CPU and then saves it in a single
+checkpoint file which can then be loaded in a single process. Keep in
+mind that this way of checkpointing is slower and limited by CPU
+memory.
-- 
GitLab