Simulation and Data Lab Applied Machine Learning / PyTorch at JSC / Commits

Commit 7274481b, authored 11 months ago by Jan Ebert

Explain FSDP

Parent: e72c8976
README.md: +60 −1 (60 additions, 1 deletion)
@@ -35,6 +35,7 @@ https://medium.com/pytorch/pytorch-data-parallel-best-practices-on-google-cloud-
- [DDP](#ddp)
- [DDP considerations](#ddp-considerations)
- [FSDP](#fsdp)
- [FSDP considerations](#fsdp-considerations)
## General
@@ -455,4 +456,62 @@ manually for the global batch size you use.
## FSDP
Currently missing.
This example will use
[`FullyShardedDataParallel`](https://pytorch.org/docs/stable/fsdp.html)
(FSDP) to do model- and data-parallel training. This means that we will
store different parts of the model on different GPUs and also evaluate
it on different batches on each GPU. Similarly, the gradients and
optimizer states are also distributed.
For initialization, FSDP first defines a hierarchy of distinct, but
possibly nested, submodules ("units") for the model. This process is
also called "wrapping" in FSDP terminology and can be controlled using
the `auto_wrap_policy` argument to `FullyShardedDataParallel`. The
parameters in each unit are then split and distributed ("sharded", or
scattered) to all GPUs. In the end, each GPU contains its own,
distinct model shard.
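
As a rough sketch of what this wrapping can look like in code (the
stand-in model, the size-based policy and its `min_num_params`
threshold are illustrative assumptions, not necessarily what this
example uses):

```python
import functools

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

# Assumes the process group is already initialized (e.g. with
# `torch.distributed.init_process_group('nccl')`) and the current CUDA
# device has been set for this rank.

# A stand-in model; the example's real model will look different.
model = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(8)])

# Treat every submodule with at least 100k parameters as its own FSDP
# "unit"; smaller submodules are folded into their parent's unit. The
# threshold is an illustrative choice.
wrap_policy = functools.partial(
    size_based_auto_wrap_policy,
    min_num_params=100_000,
)

# After this call, each rank stores only its own shard of every unit.
model = FSDP(
    model,
    auto_wrap_policy=wrap_policy,
    device_id=torch.cuda.current_device(),
)
```

A size-based policy decides unit boundaries purely by parameter count;
for transformer models, a `transformer_auto_wrap_policy` keyed on the
transformer block class is another common choice.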
Whenever we do a forward pass with the model, we sequentially pass
through units in the following way: FSDP automatically collects the
parameters in the unit that is currently being processed across all
GPUs, so that each GPU contains the full unit that is being processed.
Remember that in addition, the GPU also contains other parts of the
model from other units, but in the sharded form. Each GPU can then
execute the unit's forward pass on its own input batch to obtain its
own local activations. The GPU then discards the additional parameters
that it received. This process continues with the next unit until the
model's forward pass is completely processed and each GPU contains its
own final outputs. As a simplification, you can imagine that each GPU
first collects parameters for the first layer, does the first layer's
forward, discards the parameters again, does the same process for the
second layer and so on until it obtains the final layer's outputs.
Remember that each GPU does this process with distinct data, so the
outputs will be different on each GPU.
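
To make this gather-compute-discard cycle concrete, here is a toy,
single-process simulation of the idea. Real FSDP performs the gathering
with all-gather collectives across GPUs and overlaps communication with
computation, so this is purely illustrative:

```python
import torch

world_size = 4
layer_sizes = [(8, 8), (8, 8), (8, 4)]  # three "units": (in_features, out_features)

# The full weight of every unit, and the per-rank shards (flattened slices).
# In real FSDP, only the shards are kept after initialization.
full_weights = [torch.randn(out_f, in_f) for in_f, out_f in layer_sizes]
shards = [
    [w.flatten().chunk(world_size)[rank] for w in full_weights]
    for rank in range(world_size)
]

def forward_on_rank(rank, x):
    for unit, (in_f, out_f) in enumerate(layer_sizes):
        # "All-gather": reassemble the full weight of the current unit only.
        weight = torch.cat(
            [shards[r][unit] for r in range(world_size)]
        ).view(out_f, in_f)
        x = torch.relu(x @ weight.t())
        # "Discard": `weight` goes out of scope; only the shard stays stored.
    return x

# Each rank runs the same loop on its own, different batch, so outputs differ.
outputs = [forward_on_rank(rank, torch.randn(2, 8)) for rank in range(world_size)]
```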
When taking gradients (`loss.backward()`), a similar scheme is
executed in reverse: we collect the parameters for the very final unit
on each GPU, calculate the local gradients for the final unit, and
again discard the additional parameters. We then also want to average
those local gradients across all processes to obtain shared global
per-unit gradients, so that – just like in DDP – each update step will
be the same across all processes. After averaging, the gradients are
sharded so that each GPU only receives the gradients that it requires
to update its own shard of the unit that is currently being processed.
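
The combination of "average across all ranks, then keep only your own
slice" is exactly what a reduce-scatter collective computes. The
following single-process toy shows only the arithmetic, not FSDP's
actual implementation:

```python
import torch

world_size = 4
numel = 16  # flattened gradient size of one unit

# Each rank computed a full local gradient for the current unit from
# its own batch, so the values differ per rank.
local_grads = [torch.randn(numel) for _ in range(world_size)]

# Average across ranks ("reduce")...
avg_grad = torch.stack(local_grads).mean(dim=0)
# ...then each rank keeps only the slice matching its parameter shard ("scatter").
grad_shards = avg_grad.chunk(world_size)
# grad_shards[r] is what rank r uses to update its own parameter shard.
```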
### FSDP considerations
All the [DDP considerations](#ddp-considerations) apply.
Additionally, working with checkpoints becomes even more bothersome.
Since each GPU contains its own model shard and the full model would
potentially not fit even on the CPU, we simply save each GPU's own
shards in a separate checkpoint file. The same goes for loading the
checkpoint: each GPU loads its own checkpoint file for its model
shard. This also means that we have to execute saving and loading on
every process, since the data is fully distinct.
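
One way such per-rank checkpointing can be written uses FSDP's
`state_dict_type` context manager; the helper names, the file naming
scheme, and the use of `LOCAL_STATE_DICT` below are assumptions, and
the example's actual functions may differ:

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType


def save_sharded(model: FSDP, prefix: str = 'checkpoint'):
    """Save this rank's own shard; must be called on every rank."""
    path = f'{prefix}_rank{dist.get_rank()}.pt'  # hypothetical naming scheme
    with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
        torch.save(model.state_dict(), path)


def load_sharded(model: FSDP, prefix: str = 'checkpoint'):
    """Load this rank's own shard; must be called on every rank."""
    path = f'{prefix}_rank{dist.get_rank()}.pt'
    with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
        model.load_state_dict(torch.load(path))
```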
The example also contains an unused `save_model_singular` function
that gathers the full model on the CPU and then saves it in a single
checkpoint file which can then be loaded in a single process. Keep in
mind that this way of checkpointing is slower and limited by CPU
memory.
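
A possible implementation of such a function, sketched under the
assumption that PyTorch's `FULL_STATE_DICT` machinery with CPU
offloading is used (the repository's `save_model_singular` may be
written differently):

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullStateDictConfig, StateDictType


def save_model_singular_sketch(model: FSDP, path: str):
    """Gather the full (unsharded) model on CPU and save it from rank 0."""
    # Offload to CPU so the full model never has to fit in GPU memory,
    # and only materialize it on rank 0.
    cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
        state_dict = model.state_dict()  # collective: every rank must call this
    if dist.get_rank() == 0:
        torch.save(state_dict, path)
```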