Commit 167e0078 authored by Jan Ebert
Add 176B training script

To upgrade an already set-up environment:
```bash
cd run_scripts
source variables.bash
ln -s ../fixed_torch_run.py "$MEGATRON_DEEPSPEED_REPO"
```
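If `$MEGATRON_DEEPSPEED_REPO` does not sit directly inside this repository's root, the relative link target above will dangle. An absolute target avoids that; a sketch (assuming you are still in `run_scripts`) mirroring what the setup script below does via `$OLDPWD`:

```bash
ln -sf "$(pwd)/../fixed_torch_run.py" "$MEGATRON_DEEPSPEED_REPO"/fixed_torch_run.py
```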
parent 0a8561dc
"""Wrapper around ``torch.distributed.run`` that patches hostname resolution.

The elastic agent resolves its own fully qualified hostname; here we make
it resolve the ``--rdzv_endpoint`` host instead, then delegate to the
stock launcher.
"""

from argparse import ArgumentParser
import ipaddress
import runpy
import socket

from torch.distributed.elastic.agent.server import api as sapi


def parse_host():
    """Return the host part of ``--rdzv_endpoint`` (``host[:port]``), if given."""
    parser = ArgumentParser()
    parser.add_argument('--rdzv_endpoint')
    # Ignore all other arguments; they are passed through to the launcher.
    endpoint = parser.parse_known_args()[0].rdzv_endpoint
    host = (
        endpoint.split(':', 1)[0]
        if endpoint
        else None
    )
    return host


def fix_torch_run(host):
    """Monkeypatch the elastic agent's hostname lookup to resolve `host`."""
    _orig_get_fq_hostname = sapi._get_fq_hostname

    if host:
        try:
            ipaddress.ip_address(host)
            is_ip = True
        except ValueError:
            is_ip = False

        if is_ip:
            # Reverse-resolve the IP address to its canonical hostname.
            def new_get_fq_hostname():
                return socket.gethostbyaddr(host)[0]
        else:
            # Expand a (possibly short) hostname to its fully qualified form.
            def new_get_fq_hostname():
                return socket.getfqdn(host)
    else:
        # No endpoint given; keep the original lookup.
        new_get_fq_hostname = _orig_get_fq_hostname

    sapi._get_fq_hostname = new_get_fq_hostname


def main():
    host = parse_host()
    fix_torch_run(host)
    # Delegate to the stock launcher with all original arguments intact.
    runpy.run_module('torch.distributed.run', run_name='__main__')


if __name__ == '__main__':
    main()
@@ -43,6 +43,10 @@ python -m pip install --upgrade pip
cd "$MEGATRON_DEEPSPEED_REPO"
((DO_PULL)) && git stash && git pull --rebase origin main && git stash pop || :
# Link repaired start script to code repo.
ln -sf "$OLDPWD"/../fixed_torch_run.py ./fixed_torch_run.py
git am "$OLDPWD"/../0001-Build-fused-kernels-in-temporary-directory.patch \
    || (
        git am --abort;
...
#!/bin/bash
#SBATCH --job-name=tr11-176B-ml
#SBATCH --nodes=96
#SBATCH --ntasks-per-node=1 # crucial - only 1 task per node; the launcher spawns the per-GPU processes!
#SBATCH --cpus-per-task=48 # number of cores per task
#SBATCH --hint=nomultithread # we get physical cores, not logical ones
#SBATCH --gres=gpu:4 # number of gpus
#SBATCH --time 00:10:00 # maximum execution time (HH:MM:SS)
#SBATCH --output=%x-%j.out # output file name
#SBATCH --account=opengptx-elm
# Use `develbooster` for debugging, `booster` for "normal" jobs, and
# `largebooster` for jobs on more than 256 nodes.
#SBATCH --partition=booster
set -x -e
echo "START TIME: $(date)"
CLEAN_PREV_JIT_BUILD=0
if ! [ -e activate.bash ]; then
echo 'Please execute the sbatch script from the `run_scripts` directory.'
exit 1
fi
source activate.bash || exit 1
variant=main
# The following paths might already be set by a long-running session started via StartLongRun.sh.
[ "x$DATA_OUTPUT_PATH" = x ] && DATA_OUTPUT_PATH="$ROOT_OUTPUT_DIR"/output_dir/tr11-176B-ml/"$variant"
[ "x$CHECKPOINT_PATH" = x ] && CHECKPOINT_PATH=$DATA_OUTPUT_PATH/checkpoints
[ "x$TENSORBOARD_PATH" = x ] && TENSORBOARD_PATH=$DATA_OUTPUT_PATH/tensorboard
[ "x$CODECARBON_PATH" = x ] && CODECARBON_PATH=$DATA_OUTPUT_PATH/codecarbon
[ "x$LOGS_PATH" = x ] && LOGS_PATH=$DATA_OUTPUT_PATH/logs
mkdir -p "$LOGS_PATH"
cd "$MEGATRON_DEEPSPEED_REPO"
rm -f megatron/fused_kernels/build/lock
((CLEAN_PREV_JIT_BUILD)) && rm -rf megatron/fused_kernels/{build,__pycache__}
KILL_SWITCH_PATH="$MEGATRON_DEEPSPEED_REPO"/kill-switch-tr11-176B-exp1
# BIGSCIENCE_REPO=$six_ALL_CCFRWORK/code/tr11-176B-ml/bigscience
# TRAIN_DATA_PATH=$MEGATRON_DEEPSPEED_REPO/data/train-splits.txt
# VALID_DATA_PATH=$MEGATRON_DEEPSPEED_REPO/data/valid-splits.txt
# CATALOGUE_JSON_PATH=$BIGSCIENCE_REPO/data/catalogue/training_dataset_ratios_merged_nigercongo_v3.json
# LOAD_RATIOS_SCRIPT=$BIGSCIENCE_REPO/data/catalogue/load_ratios_meg_ds_format.py
# python $LOAD_RATIOS_SCRIPT --dataset-ratios-path $CATALOGUE_JSON_PATH --split train --output-meg-ds-ratio-file $TRAIN_DATA_PATH
# python $LOAD_RATIOS_SCRIPT --dataset-ratios-path $CATALOGUE_JSON_PATH --split valid --output-meg-ds-ratio-file $VALID_DATA_PATH
# TOKENIZER_NAME_OR_PATH=bigscience-catalogue-data-dev/byte-level-bpe-tokenizer-no-norm-250k-whitespace-and-eos-regex-alpha-v3-dedup-lines-articles
# Define the required environment variables.
CACHE_DIR="$ROOT_OUTPUT_DIR/.cache"
mkdir -p "$CACHE_DIR"
export TRANSFORMERS_CACHE="$CACHE_DIR"/models
export HF_DATASETS_CACHE="$CACHE_DIR"/datasets
export HF_MODULES_CACHE="$CACHE_DIR"/modules
export HF_METRICS_CACHE="$CACHE_DIR"/metrics
export HF_DATASETS_OFFLINE=1
export TRANSFORMERS_OFFLINE=1
# Test for potentially faulty nodes:
# srun --jobid $SLURM_JOB_ID bash -c 'python -c "import torch, socket; print(socket.gethostname(), torch.cuda.is_available())"'
# so processes know who to talk to
MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)"
# Allow communication over InfiniBand cells.
MASTER_ADDR="${MASTER_ADDR}i"
# Get IP for hostname.
MASTER_ADDR="$(nslookup "$MASTER_ADDR" | grep -oP '(?<=Address: ).*')"
MASTER_PORT=6000
GPUS_PER_NODE=4
NNODES=$SLURM_NNODES
# Could also use half as many nodes (48), half PP_SIZE (12), but
# double TP_SIZE (8).
TP_SIZE=4
PP_SIZE=24
MICRO_BATCH_SIZE=2 # was MBS=1 till GBS=784
# GLOBAL_BATCH_SIZE=2048 # 4.2M tokens. It is larger than the initial plan of 3.2M tokens to get higher throughput
GLOBAL_BATCH_SIZE=$(((NNODES * GPUS_PER_NODE / (PP_SIZE * TP_SIZE)) * MICRO_BATCH_SIZE))
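# With the values above: 96 nodes * 4 GPUs = 384 GPUs in total; one model
# replica spans PP_SIZE * TP_SIZE = 24 * 4 = 96 GPUs, giving 384 / 96 = 4
# data-parallel replicas, so GBS = 4 * 2 = 8 samples per step.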
NHIDDEN=14336
NLAYERS=70
NHEADS=112
SEQ_LEN=2048
SAVE_INTERVAL=100
TRAIN_SAMPLES=220_000_000 # 450B tokens
LR_DECAY_SAMPLES=200_000_000 # Decay for the first 410B tokens then continue at fixed --min-lr
LR_WARMUP_SAMPLES=183_105 # 375M tokens
OPTIMIZER_ARGS=" \
--optimizer adam \
--adam-beta1 0.9 \
--adam-beta2 0.95 \
--adam-eps 1e-8 \
--lr 6e-5 \
--min-lr 6e-6 \
--lr-decay-style cosine \
--lr-decay-samples $LR_DECAY_SAMPLES \
--lr-warmup-samples $LR_WARMUP_SAMPLES \
--clip-grad 1.0 \
--weight-decay 1e-1 \
"
# Exit before the wall-clock limit: 1190 min for a 20 h job, 5990 min for a 100 h job.
# --exit-duration-in-mins 1190 \
EXIT_OPTS=" \
    --exit-duration-in-mins 5990 \
    "
# --tokenizer-type PretrainedFromHF \
# --tokenizer-name-or-path $TOKENIZER_NAME_OR_PATH \
# --rampup-batch-size 192 16 9_765_625 \
GPT_ARGS=" \
--pp-partition-method 'type:transformer|embedding' \
--num-layers $NLAYERS \
--hidden-size $NHIDDEN \
--num-attention-heads $NHEADS \
--seq-length $SEQ_LEN \
--max-position-embeddings $SEQ_LEN \
--micro-batch-size $MICRO_BATCH_SIZE \
--global-batch-size $GLOBAL_BATCH_SIZE \
--train-samples $TRAIN_SAMPLES \
--vocab-file $VOCAB_FILE \
--merge-file $MERGE_FILE \
--tokenizer-type GPT2BPETokenizer \
--init-method-std 0.0048 \
--embed-layernorm \
--sync-tp-duplicated-parameters \
--bf16 \
--seed 42 \
--position-embedding-type alibi \
--checkpoint-activations \
--abort-on-unmet-fused-kernel-constraints \
--kill-switch-path $KILL_SWITCH_PATH \
--pad-vocab-size-to 250880 \
$OPTIMIZER_ARGS \
$EXIT_OPTS \
"
# TODO: decide on efficient eval-interval + eval-iters
OUTPUT_ARGS=" \
--log-interval 1 \
--save-interval $SAVE_INTERVAL \
--eval-interval 1000 \
--eval-iters 1 \
--tensorboard-dir $TENSORBOARD_PATH \
--tensorboard-queue-size 5 \
--log-timers-to-tensorboard \
--log-batch-size-to-tensorboard \
--log-validation-ppl-to-tensorboard \
"
ZERO_STAGE=0 # important: bf16 must use ZeRO stage 0! The bf16 optimizer implements its own ZeRO stage 1 equivalent.
config_json="./ds_config.$SLURM_JOB_ID.json"
# DeepSpeed figures out gradient accumulation steps (GAS) dynamically from the dynamic GBS via set_train_batch_size().
cat <<EOT > "$config_json"
{
  "train_micro_batch_size_per_gpu": $MICRO_BATCH_SIZE,
  "train_batch_size": $GLOBAL_BATCH_SIZE,
  "gradient_clipping": 1.0,
  "zero_optimization": {
    "stage": $ZERO_STAGE
  },
  "bf16": {
    "enabled": true
  },
  "steps_per_print": 2000,
  "wall_clock_breakdown": false
}
EOT
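# Sanity check: DeepSpeed derives gradient accumulation steps as
# train_batch_size / (train_micro_batch_size_per_gpu * data-parallel size),
# i.e. 8 / (2 * 4) = 1 with the values above, so no gradient accumulation.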
DEEPSPEED_ARGS=" \
--deepspeed \
--deepspeed_config ${config_json} \
--zero-stage ${ZERO_STAGE} \
--deepspeed-activation-checkpointing \
"
# export LAUNCHER="python -u -m torch.distributed.run \
export LAUNCHER="python -u -m fixed_torch_run \
    --nproc_per_node $GPUS_PER_NODE \
    --nnodes $NNODES \
    --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \
    --rdzv_backend c10d \
    --rdzv_conf=is_host=\$(if ((SLURM_NODEID)); then echo 0; else echo 1; fi) \
    --max_restarts 0 \
    --tee 3 \
    "
# --universal-checkpoint \
export CMD=" \
    $(pwd)/pretrain_gpt.py \
    --tensor-model-parallel-size $TP_SIZE \
    --pipeline-model-parallel-size $PP_SIZE \
    $GPT_ARGS \
    $OUTPUT_ARGS \
    --save $CHECKPOINT_PATH \
    --data-path $DATA_PATH \
    --split 949,50,1 \
    --num-workers 2 \
    --valid-num-workers 0 \
    --data-impl mmap \
    --distributed-backend nccl \
    $DEEPSPEED_ARGS \
    "
if [ "$LOAD_CHECKPOINTS" = true ] ; then
export CMD="$CMD\
--load $CHECKPOINT_PATH \
"
fi
echo $CMD
# Do not remove this workaround, or the training will hang and nodes will be lost.
export CUDA_LAUNCH_BLOCKING=1
# hide duplicated errors using this hack - will be properly fixed in pt-1.12
export TORCHELASTIC_ERROR_FILE=/tmp/torch-elastic-error.json
# force crashing on nccl issues like hanging broadcast
export NCCL_ASYNC_ERROR_HANDLING=1
# Raise the InfiniBand timeout to better tolerate transient network stalls.
export NCCL_IB_TIMEOUT=20
# srun error handling:
# --wait=60: wait 60 sec after the first task terminates before terminating all remaining tasks
# --kill-on-bad-exit=1: terminate a step if any task exits with a non-zero exit code
SRUN_ARGS=" \
    --wait=60 \
    --kill-on-bad-exit=1 \
"
clear; srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$LAUNCHER --node_rank \$SLURM_PROCID $CMD" 2>&1 | tee -a "$LOGS_PATH"/main_log.txt
echo "END TIME: $(date)"