PyTorch at JSC
This recipe provides two examples of running parallel PyTorch training of
a ResNet-50 on fake image data.
The examples cover training (1) in data-parallel fashion using
`DistributedDataParallel` (DDP),
which is useful for models that fit on one GPU, and
(2) in data- and model-parallel fashion using
`FullyShardedDataParallel` (FSDP),
which is useful for models that do not fit on one GPU.
For more information about DDP vs. FSDP, there is also this great article about when either of the two results in faster training: https://medium.com/pytorch/pytorch-data-parallel-best-practices-on-google-cloud-6c8da2be180d
General
We will sometimes use standard parallel computing lingo in this guide and assume a basic understanding of how processes execute in parallel. If you are unfamiliar with the terms "rank" or "world size" in the context of parallel computing, or with how parallel code actually works in practice, what follows is a brief description. Since this section is deliberately brief, consider checking out other tutorials, such as one of our courses, for more detailed descriptions and examples.
Before we get to the descriptions, always keep in mind that we start
independent processes that do not actually know about each other. Each
process just happily does its own thing until we – at one point – call
a function that actually communicates across the processes (like
`torch.distributed.all_reduce`, which sums up a tensor across all the
processes by default). All processes evaluate the same code at roughly
the same time (we are never exactly parallel or completely
synchronized due to nature being noisy) but in different, distinct
Python processes, so when we say that a process is "evaluating" the
code, we really just mean that one of the processes reached this code
at some point in time, and now this process is doing something with
the code, but we have no idea about what's going on in the other
processes. We do assign numbers to our processes so that we can at
least figure out which process is evaluating the code; making
decisions based on these numbers is one of the main paradigms of
parallel code.
- world size: the number of processes across all nodes
- rank: the index of the evaluating process, goes from 0 to world size - 1
- local world size: the number of processes per node
- local rank: the index of the evaluating process considering only processes on the same node, goes from 0 to local world size - 1 on each node. This means that once you are on more than one node, multiple processes will have the same local ranks if they are on different nodes. In our case, we always start as many processes per node as we have GPUs on them, so the local rank can also be considered the per-node index of the GPU that the evaluating process should use.
Again, since this is a very brief explanation, consider looking into other tutorials such as one of our courses for more information if this wasn't helpful to you.
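As a minimal sketch of these terms in code (this is not part of the example; it assumes the script was started via `torchrun`/`torchrun_jsc`, which sets the `LOCAL_RANK` and `LOCAL_WORLD_SIZE` environment variables, and that the process group has already been initialized):

```python
import os

import torch.distributed as dist

rank = dist.get_rank()              # 0 .. world size - 1, across all nodes
world_size = dist.get_world_size()  # number of processes across all nodes
local_rank = int(os.environ['LOCAL_RANK'])              # 0 .. local world size - 1
local_world_size = int(os.environ['LOCAL_WORLD_SIZE'])  # processes on this node

print(f'rank {rank} of {world_size} '
      f'(local rank {local_rank} of {local_world_size} on this node)')
```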
Parallel code
The most important part about parallel code in general is managing
output correctly. If you start to write to the same file with multiple
processes at the same time, it will become corrupted; the file system
does not know that we are writing to it in parallel. Similarly, if we
print our training progress, we do not want to do it from each process
individually; otherwise we would get the same output as many times as
we have processes. Thankfully, there is a very simple way to handle
this: whenever we create a directory or a file, or write something, we
make sure that it is only done by one process, by checking whether we
are on a certain process. Usually process 0 (the first process) is
chosen for this because you will always have a process 0, even if you
are not running multiple processes in parallel. You can use
`torch.distributed.get_rank() == 0` to check whether the process
evaluating the code is process 0, and then you can for example write
model checkpoints or print progress if – and only if – the evaluating
process is, in fact, process 0.
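For example, a minimal sketch of such a rank-0-only check (assuming an initialized process group; the model here is just a stand-in):

```python
import torch
import torch.distributed as dist

model = torch.nn.Linear(8, 2)  # stand-in for the real model

# Only process 0 writes checkpoints and prints progress.
if dist.get_rank() == 0:
    torch.save(model.state_dict(), 'checkpoint.pt')
    print('checkpoint saved')
```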
We use the local rank (which PyTorch sets as the environment variable
`LOCAL_RANK`) to select a different GPU on each node to place our
model and data on. Because we are not using a higher-level library, we
need to do this manually.
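A rough sketch of this manual GPU selection (the model is a stand-in):

```python
import os

import torch

# `LOCAL_RANK` is set for each process by `torchrun`/`torchrun_jsc`.
local_rank = int(os.environ['LOCAL_RANK'])
device = torch.device('cuda', local_rank)
torch.cuda.set_device(device)

# The model (and later each batch of data) is placed on this per-process GPU.
model = torch.nn.Linear(8, 2).to(device)  # stand-in for the real model
```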
Parallel data processing
To make use of all the CPUs on the system, we can enable parallel data
processing by passing the `--train-num-workers` and
`--valid-num-workers` arguments. These just set the corresponding
`num_workers` argument for the respective `DataLoader`s.
It can be a good idea to use all available CPUs except one, which is
kept for the main process.
For example: in our example `sbatch` script for JUWELS Booster, we
select 12 CPUs per task. We can then pass
`--train-num-workers=8 --valid-num-workers=3`. With 8 + 3 = 11 workers
in total, we leave one CPU for our main process.
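As a sketch of what these arguments end up controlling (the datasets and batch size here are stand-ins, not the example's actual values):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-ins for the real training and validation datasets.
train_data = TensorDataset(torch.randn(64, 3, 224, 224), torch.zeros(64, dtype=torch.long))
valid_data = TensorDataset(torch.randn(16, 3, 224, 224), torch.zeros(16, dtype=torch.long))

# `--train-num-workers` and `--valid-num-workers` set `num_workers` here.
train_loader = DataLoader(train_data, batch_size=32, num_workers=8)
valid_loader = DataLoader(valid_data, batch_size=32, num_workers=3)
```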
Launching the processes
PyTorch needs us to specify an endpoint, which is used to initialize
its distributed system. One part of this endpoint is the master
address (`MASTER_ADDR`), which is an IP or hostname that all processes
have to connect to for initialization. We can thankfully obtain the
hostname of our job's first allocated node using Slurm relatively
easily. However, on JSC systems, we should not use this hostname as-is
because of awkward naming with regard to InfiniBand network
interfaces. If we used this hostname, nodes that are "too far apart"
would not be able to talk to each other. The simple fix is to append
an "i" to the hostname to use the correct network interface. Since
this special case only affects some JSC systems, we query the machine
name and append the "i" only if necessary in the example.
As if all of the above wasn't enough, PyTorch also won't be able to do
communication via Gloo, the distributed backend it uses for another
initialization step, unless we set the environment variable
`GLOO_SOCKET_IFNAME=ib0`.
The `MASTER_PORT` part of the endpoint can be a relatively arbitrary
port number.
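The example's `sbatch` script does this in bash; purely as an illustration, a rough Python equivalent of the logic described above might look like the sketch below (the list of affected systems mirrors the PyTorch Lightning patch shown later; the port number is arbitrary):

```python
import os
import subprocess

# Hostname of the job's first allocated node, obtained via Slurm.
first_node = subprocess.check_output(
    ['scontrol', 'show', 'hostnames', os.environ['SLURM_JOB_NODELIST']],
    text=True,
).splitlines()[0]

# On the affected JSC systems, append an "i" so the InfiniBand interface is used.
if os.getenv('SYSTEMNAME', '') in ['juwelsbooster', 'juwels', 'jurecadc', 'jusuf']:
    first_node += 'i'

os.environ['MASTER_ADDR'] = first_node
os.environ['MASTER_PORT'] = '29500'       # relatively arbitrary port number
os.environ['GLOO_SOCKET_IFNAME'] = 'ib0'  # required for Gloo communication
```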
PyTorch offers the `torchrun` API for launching parallel processes.
Sadly, there are some issues with this API that make its usage on JSC
systems difficult because of the aforementioned special hostname
handling (see here for more information).
Thankfully, there are options to fix these issues:
- Use wrappers, such as `torchrun_jsc`. It modifies the underlying code on-the-fly to fix the issues. `torchrun_jsc`/`python -m torchrun_jsc` is a drop-in replacement for `torchrun`/`python -m torch.distributed.run` and can be installed via `pip`: `python -m pip install torchrun_jsc`.
- Use PyTorch as provided by the module system. We include patches to ensure that the errors in `torchrun` are fixed and that it reliably works on our system.
In our example, we always use the wrapper, even if we are already
using the module system, to show off how to use it. This way, you can
apply the same template to your own projects that may use
`pip`-installed PyTorch versions, or even a container. This also means
that you need to set up a virtual environment with `torchrun_jsc`
installed before being able to use the example out-of-the-box. This
can be done by executing `bash set_up.sh` once on a login node.
Job submission
As a reminder, before being able to submit a job, you have to manually
create an environment by executing `bash set_up.sh` once on a login
node.

The `sbatch` scripts are written so that they take arguments like a
usual script. To launch a job with different arguments, you can just
pass your desired arguments for the Python script to `sbatch`, like so:
```bash
sbatch run.sbatch --train-num-workers=8 --valid-num-workers=3
```
Warnings upon PyTorch Distributed initialization
You can safely ignore warnings like the following:
```
[W socket.cpp:436] [c10d] The server socket cannot be initialized on [::]:54123 (errno: 97 - Address family not supported by protocol).
[W socket.cpp:663] [c10d] The client socket cannot be initialized to connect to [jwb0001i.juwels]:54123 (errno: 97 - Address family not supported by protocol).
[W socket.cpp:663] [c10d] The client socket cannot be initialized to connect to [jwb0001i.juwels]:32164 (errno: 97 - Address family not supported by protocol).
```
We have not noticed performance degradations or errors once PyTorch started to emit these warnings.
File system problems
Due to file system limits on the number of inodes ("number of files") and the amount of storage space available to us, we can run into issues.
Cache directories
PyTorch and the libraries it uses like to save compiled code, downloaded models, or downloaded datasets to cache directories. By default, most of these point to your home directory, quickly consuming the limited available space.
Ideally, you can soft-link these cache directories to a project's
SCRATCH directory, or set the corresponding environment variables in
your scripts. As a general recommendation, soft-linking the entire
`~/.cache` directory to SCRATCH is a good idea because many programs,
including `pip`, use it. This will also handle most of the cache
directories mentioned here.
Here are some of the variables concerning various libraries and their default values:
- PyTorch Hub: `TORCH_HOME="$HOME"/.cache/torch/hub`
- PyTorch extensions: `TORCH_EXTENSIONS_DIR="$HOME"/.cache/torch_extensions`
- Triton (PyTorch dependency): `TRITON_CACHE_DIR="$HOME"/.triton/cache`
- HuggingFace: `HF_HOME="$HOME"/.cache/huggingface`
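As an illustration only (the example scripts would typically export such variables in the shell instead), the caches could be redirected in Python before the relevant imports; this sketch assumes a `SCRATCH` environment variable pointing to your project's scratch directory:

```python
import os

scratch = os.environ['SCRATCH']  # assumed to point to your project's SCRATCH

os.environ.setdefault('TORCH_HOME', os.path.join(scratch, '.cache/torch/hub'))
os.environ.setdefault('TORCH_EXTENSIONS_DIR', os.path.join(scratch, '.cache/torch_extensions'))
os.environ.setdefault('TRITON_CACHE_DIR', os.path.join(scratch, '.triton/cache'))
os.environ.setdefault('HF_HOME', os.path.join(scratch, '.cache/huggingface'))

# Import the libraries only after the variables have been set.
import torch  # noqa: E402
```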
`venv` directories
The `venv`s we create can contain very many small files, or very large
binary blobs of compiled code. Both of these can lead to us reaching
file system limits. To avoid these problems, set up your `venv`s in
SCRATCH. The example scripts here do not follow this practice for
simplicity, but please consider it in your own projects. Be mindful
that files in SCRATCH are deleted after 90 days of not being touched,
so make sure that the environment is reproducible (e.g., by saving an
up-to-date `modules.sh` and `requirements.txt` in PROJECT).
GPU kernel compilation
Sometimes, additional specifications are required to build GPU kernels.
GPU architecture selection
Some libraries may require you to explicitly specify the compute architecture of your GPU for them to successfully build.
This can be done by setting the environment variable
`TORCH_CUDA_ARCH_LIST`. It can be used to specify a list of CUDA
compute capabilities of the target GPU architectures that kernels will
be built for. Entries can be separated using semicolons. Compute
capabilities for various CUDA-enabled GPUs can be found on this page:
https://developer.nvidia.com/cuda-gpus
For example, if we want to build kernels that target NVIDIA V100 (compute capability 7.0), NVIDIA A100 (compute capability 8.0), and NVIDIA H100 (compute capability 9.0) GPUs, we would set:
```bash
export TORCH_CUDA_ARCH_LIST="7.0;8.0;9.0"
```

Note that the value has to be quoted so the shell does not interpret the semicolons.
Kernels not being compiled
You may also find that some Python packages do not build GPU kernels
by default even if `TORCH_CUDA_ARCH_LIST` is specified. This can
happen if kernels are only built when a GPU is actually found on the
system setting up the environment. Since we are building the
environment on a login node, we won't have a GPU available. However,
we are still able to compile kernels, as the kernel compiler is
available, and that is all we require. Usually, libraries offer an
escape hatch via environment variables so you can still force GPU
kernel compilation manually. If these are not documented, you can try
to look for such escape hatches in the package's `setup.py`. Maybe an
AI chatbot can be helpful in finding these.
PyTorch Lightning
If you are using PyTorch Lightning, you should launch jobs
differently. Instead of using `torchrun`, you can just launch the
Python script directly. This change also requires you to set
`#SBATCH --ntasks-per-node=4`. Again, this is because instead of
`torchrun` launching the processes, we want to let Slurm do it and
then let PyTorch Lightning handle initialization between the processes
all by itself.
In the following code snippets, remember this is just an
example/template for PyTorch Lightning usage. You need to write your
own `main.py` with PyTorch Lightning in mind. Here are the relevant
changes to the `sbatch` script if you want to start a parallel job
using PyTorch Lightning:
Instead of

```bash
#SBATCH --ntasks-per-node=1
```

use

```bash
#SBATCH --ntasks-per-node=4
```

Instead of

```bash
srun env -u CUDA_VISIBLE_DEVICES python -u -m torchrun_jsc \
    --nproc_per_node=gpu \
    --nnodes="$SLURM_JOB_NUM_NODES" \
    --rdzv_id="$SLURM_JOB_ID" \
    --rdzv_endpoint="$MASTER_ADDR":"$MASTER_PORT" \
    --rdzv_backend=c10d \
    "$curr_dir"/main.py "$@"
```

use

```bash
srun env -u CUDA_VISIBLE_DEVICES python -u "$curr_dir"/main.py "$@"
```
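For reference, a minimal, hypothetical `main.py` skeleton for this Lightning setup could look like the sketch below (it assumes the Lightning ≥2 namespace; the model, data, and `Trainer` arguments are placeholders you would replace with your own):

```python
import os

import lightning.pytorch as pl
import torch
from torch.utils.data import DataLoader, TensorDataset


class ToyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


def main():
    data = TensorDataset(torch.randn(256, 8), torch.randint(0, 2, (256,)))
    trainer = pl.Trainer(
        accelerator='gpu',
        devices=int(os.environ['SLURM_NTASKS_PER_NODE']),  # 4, matching the sbatch header
        num_nodes=int(os.environ['SLURM_JOB_NUM_NODES']),
        strategy='ddp',
        max_epochs=1,
    )
    # Lightning detects the Slurm environment and sets up the processes itself.
    trainer.fit(ToyModel(), DataLoader(data, batch_size=32))


if __name__ == '__main__':
    main()
```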
Additionally, if using PyTorch Lightning, you may encounter issues
when running jobs on many nodes. This is because of the aforementioned
hostname issue. If you export the `MASTER_ADDR` environment variable
like in the example scripts (that means including the "i"), PyTorch
Lightning will be able to pick it up. However, this is only
implemented for PyTorch Lightning ≥2.1. Previous versions require
modifying the underlying PyTorch Lightning code, which can be achieved
by putting the following code at the start of your main script:
```python
import os

try:
    from lightning.pytorch.plugins.environments import SLURMEnvironment
except (ModuleNotFoundError, ImportError):
    # For PyTorch Lightning <2, this namespace needs to be used instead.
    from pytorch_lightning.plugins.environments import SLURMEnvironment


def patch_lightning_slurm_master_addr():
    # Do not patch anything if we're not on a Jülich machine.
    if os.getenv('SYSTEMNAME', '') not in [
            'juwelsbooster',
            'juwels',
            'jurecadc',
            'jusuf',
    ]:
        return

    old_resolver = SLURMEnvironment.resolve_root_node_address

    def new_resolver(*args):
        nodes = args[-1]
        # Append an "i" for communication over InfiniBand.
        return old_resolver(nodes) + 'i'

    SLURMEnvironment.__old_resolve_root_node_address = old_resolver
    SLURMEnvironment.resolve_root_node_address = new_resolver


patch_lightning_slurm_master_addr()
```
Advanced PyTorch Distributed debugging
To enable logging for the Python parts of PyTorch Distributed, please check out the official documentation. At the time of writing this recipe, there is no easy way to enable logging for all PyTorch Distributed modules. Instead, you have to enable it individually, e.g.:
```bash
export TORCH_LOGS='+torch.distributed.elastic.agent.server.api,+torch.distributed.elastic.agent.server.local_elastic_agent,+torch.distributed.elastic.rendezvous.dynamic_rendezvous,+torch.distributed.elastic.rendezvous.c10d_rendezvous_backend,+torch.distributed.elastic.rendezvous.utils,+torch.distributed.distributed_c10d'
```
If you are facing very advanced problems and want to enable debug logging for PyTorch Distributed C++ parts, you have to set the following environment variables:
```bash
export TORCH_DISTRIBUTED_DEBUG=DETAIL
export TORCH_CPP_LOG_LEVEL=INFO
```
DDP
This example will use `DistributedDataParallel` (DDP)
to do data-parallel training. This means that we will evaluate the
same model on different batches on each GPU. We copy the model at the
start of training to all GPUs so we have the same initial setup on
each process. Whenever we take gradients (`loss.backward()`), the
gradients are averaged across all processes, so that each update step
will be the same across all processes. This way, our model never
diverges between processes; the model is always ensured to stay the
same on all GPUs.
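In code, the DDP wrapping is roughly the following sketch (assuming an initialized process group and the per-process device selection shown earlier; the model is a stand-in for the ResNet-50):

```python
import os

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

local_rank = int(os.environ['LOCAL_RANK'])
device = torch.device('cuda', local_rank)

model = torch.nn.Linear(8, 2).to(device)  # stand-in for the real model
# DDP broadcasts the initial parameters and averages gradients during
# `loss.backward()`.
model = DDP(model, device_ids=[local_rank])
```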
DDP considerations
The initial copy of the model to all processes and the reduction of
gradients are all taken care of by the DDP wrapper. However, if
we only added the DDP wrapper, our model would train on the same data
on each process; we wouldn't actually gain anything from the
parallelization. So in addition, we have to split our data so that
each process sees its own distinct subset. This is also called
"sharding" the data, because each process obtains its own shard of the
full dataset. Thankfully, PyTorch offers the `DistributedSampler`,
which we can apply to our `DataLoader`s
to achieve the aforementioned sharding of data. Note that for each
training epoch, we have to explicitly set the epoch on the sampler so
that data is shuffled differently on each epoch.
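Sketched out, the sharding and per-epoch shuffling look roughly like this (assuming an initialized process group; the dataset is a stand-in):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Stand-in for the real training dataset.
train_data = TensorDataset(torch.randn(256, 8), torch.zeros(256, dtype=torch.long))

# Each process only iterates over its own shard of the data.
sampler = DistributedSampler(train_data, shuffle=True)
train_loader = DataLoader(train_data, batch_size=32, sampler=sampler)

for epoch in range(10):
    # Required so that shuffling differs between epochs.
    sampler.set_epoch(epoch)
    for batch in train_loader:
        ...  # training step
```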
Another thing to keep in mind: while our model is the same on each process, the data it gets evaluated on and, accordingly, the loss, will be different on each process. To get more comparable metrics, we average the loss over all processes before printing it in the example.
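A small sketch of such an averaging helper (assuming an initialized process group):

```python
import torch
import torch.distributed as dist


def average_across_processes(value: torch.Tensor) -> torch.Tensor:
    """Average a tensor (e.g., the loss) over all processes."""
    result = value.detach().clone()
    dist.all_reduce(result, op=dist.ReduceOp.SUM)
    return result / dist.get_world_size()
```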
Finally, when using DDP, the batch size changes depending on the number of processes we use. That is because we only configure the "local batch size", i.e., the batch size per process. The "global batch size" is obtained by multiplying the local batch size by the number of processes. If we scale up the number of processes, we obtain a larger batch size; this, in turn, will change what learning rate we should use. A very simple heuristic is to scale the base learning rate you would use for the local batch size proportionally to the number of processes, i.e., multiply the base learning rate by the number of processes. This is automatically done in the code so that it "just works" with a large number of processes, but ideally you would tune the learning rate manually for the global batch size you use.
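A tiny sketch of this scaling heuristic with made-up numbers (assuming an initialized process group):

```python
import torch.distributed as dist

base_lr = 0.1          # learning rate tuned for the local batch size
local_batch_size = 32  # batch size per process

world_size = dist.get_world_size()                 # e.g., 16 processes
global_batch_size = local_batch_size * world_size  # e.g., 32 * 16 = 512
scaled_lr = base_lr * world_size                   # e.g., 0.1 * 16 = 1.6
```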
FSDP
Currently missing.