Skip to content
Snippets Groups Projects
Commit 167e0078 authored by Jan Ebert's avatar Jan Ebert
Browse files

Add 176B training script

To upgrade an already setup environment:
```bash
cd run_scripts
source variables.bash
ln -s ../fixed_torch_run.py "$MEGATRON_DEEPSPEED_REPO"
```
parent 0a8561dc
Branches
No related tags found
No related merge requests found
from argparse import ArgumentParser
import ipaddress
import runpy
import socket
from torch.distributed.elastic.agent.server import api as sapi
def parse_host():
parser = ArgumentParser()
parser.add_argument('--rdzv_endpoint')
endpoint = parser.parse_known_args()[0].rdzv_endpoint
host = (
endpoint.split(':', 1)[0]
if endpoint
else None
)
return host
def fix_torch_run(host):
_orig_get_fq_hostname = sapi._get_fq_hostname
if host:
try:
ipaddress.ip_address(host)
is_ip = True
except ValueError:
is_ip = False
if is_ip:
def new_get_fq_hostname():
return socket.gethostbyaddr(host)[0]
else:
def new_get_fq_hostname():
return socket.getfqdn(host)
else:
new_get_fq_hostname = _orig_get_fq_hostname
sapi._get_fq_hostname = new_get_fq_hostname
def main():
host = parse_host()
fix_torch_run(host)
runpy.run_module('torch.distributed.run', run_name='__main__')
if __name__ == '__main__':
main()
...@@ -43,6 +43,10 @@ python -m pip install --upgrade pip ...@@ -43,6 +43,10 @@ python -m pip install --upgrade pip
cd "$MEGATRON_DEEPSPEED_REPO" cd "$MEGATRON_DEEPSPEED_REPO"
((DO_PULL)) && git stash && git pull --rebase origin main && git stash pop || : ((DO_PULL)) && git stash && git pull --rebase origin main && git stash pop || :
# Link repaired start script to code repo.
ln -sf "$OLDPWD"/../fixed_torch_run.py ./fixed_torch_run.py
git am "$OLDPWD"/../0001-Build-fused-kernels-in-temporary-directory.patch \ git am "$OLDPWD"/../0001-Build-fused-kernels-in-temporary-directory.patch \
|| ( || (
git am --abort; git am --abort;
......
#!/bin/bash
#SBATCH --job-name=tr11-176B-ml
#SBATCH --nodes=96
#SBATCH --ntasks-per-node=1 # crucial - only 1 task per dist per node!
#SBATCH --cpus-per-task=48 # number of cores per tasks
#SBATCH --hint=nomultithread # we get physical cores not logical
#SBATCH --gres=gpu:4 # number of gpus
#SBATCH --time 00:10:00 # maximum execution time (HH:MM:SS)
#SBATCH --output=%x-%j.out # output file name
#SBATCH --account=opengptx-elm
# Use `develbooster` for debugging, `booster` for "normal" jobs, and
# `largebooster` for jobs on more than 256 nodes.
#SBATCH --partition=booster
set -x -e
echo "START TIME: $(date)"
CLEAN_PREV_JIT_BUILD=0
if ! [ -e activate.bash ]; then
echo 'Please execute the sbatch script from the `run_scripts` directory.'
exit 1
fi
source activate.bash || exit 1
variant=main
# The following paths might already be set in long-running-session in StartLongRun.sh
[ "x$DATA_OUTPUT_PATH" = x ] && DATA_OUTPUT_PATH="$ROOT_OUTPUT_DIR"/output_dir/tr11-176B-ml/"$variant"
[ "x$CHECKPOINT_PATH" = x ] && CHECKPOINT_PATH=$DATA_OUTPUT_PATH/checkpoints
[ "x$TENSORBOARD_PATH" = x ] && TENSORBOARD_PATH=$DATA_OUTPUT_PATH/tensorboard
[ "x$CODECARBON_PATH" = x ] && CODECARBON_PATH=$DATA_OUTPUT_PATH/codecarbon
[ "x$LOGS_PATH" = x ] && LOGS_PATH=$DATA_OUTPUT_PATH/logs
mkdir -p "$LOGS_PATH"
cd "$MEGATRON_DEEPSPEED_REPO"
rm -f megatron/fused_kernels/build/lock
((CLEAN_PREV_JIT_BUILD)) && rm -rf megatron/fused_kernels/{build,__pycache__}
KILL_SWITCH_PATH="$MEGATRON_DEEPSPEED_REPO"/kill-switch-tr11-176B-exp1
# BIGSCIENCE_REPO=$six_ALL_CCFRWORK/code/tr11-176B-ml/bigscience
# TRAIN_DATA_PATH=$MEGATRON_DEEPSPEED_REPO/data/train-splits.txt
# VALID_DATA_PATH=$MEGATRON_DEEPSPEED_REPO/data/valid-splits.txt
# CATALOGUE_JSON_PATH=$BIGSCIENCE_REPO/data/catalogue/training_dataset_ratios_merged_nigercongo_v3.json
# LOAD_RATIOS_SCRIPT=$BIGSCIENCE_REPO/data/catalogue/load_ratios_meg_ds_format.py
# python $LOAD_RATIOS_SCRIPT --dataset-ratios-path $CATALOGUE_JSON_PATH --split train --output-meg-ds-ratio-file $TRAIN_DATA_PATH
# python $LOAD_RATIOS_SCRIPT --dataset-ratios-path $CATALOGUE_JSON_PATH --split valid --output-meg-ds-ratio-file $VALID_DATA_PATH
# TOKENIZER_NAME_OR_PATH=bigscience-catalogue-data-dev/byte-level-bpe-tokenizer-no-norm-250k-whitespace-and-eos-regex-alpha-v3-dedup-lines-articles
# defining the right environment variables
CACHE_DIR="$ROOT_OUTPUT_DIR/.cache"
mkdir -p "$CACHE_DIR"
export TRANSFORMERS_CACHE="$CACHE_DIR"/models
export HF_DATASETS_CACHE="$CACHE_DIR"/datasets
export HF_MODULES_CACHE="$CACHE_DIR"/modules
export HF_METRICS_CACHE="$CACHE_DIR"/metrics
export HF_DATASETS_OFFLINE=1
export TRANSFORMERS_OFFLINE=1
# testing for potential faulty nodes
# srun --jobid $SLURM_JOB_ID bash -c 'python -c "import torch, socket; print(socket.gethostname(), torch.cuda.is_available())"'
# so processes know who to talk to
MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)"
# Allow communication over InfiniBand cells.
MASTER_ADDR="${MASTER_ADDR}i"
# Get IP for hostname.
MASTER_ADDR="$(nslookup "$MASTER_ADDR" | grep -oP '(?<=Address: ).*')"
MASTER_PORT=6000
GPUS_PER_NODE=4
NNODES=$SLURM_NNODES
# Could also use half as many nodes (48), half PP_SIZE (12), but
# double TP_SIZE (8).
TP_SIZE=4
PP_SIZE=24
MICRO_BATCH_SIZE=2 # was MBS=1 till GBS=784
# GLOBAL_BATCH_SIZE=2048 # 4.2M tokens. It is larger than the initial plan of 3.2M tokens to get higher throughput
GLOBAL_BATCH_SIZE=$(((NNODES * GPUS_PER_NODE / (PP_SIZE * TP_SIZE)) * MICRO_BATCH_SIZE))
NHIDDEN=14336
NLAYERS=70
NHEADS=112
SEQ_LEN=2048
SAVE_INTERVAL=100
TRAIN_SAMPLES=220_000_000 # 450B tokens
LR_DECAY_SAMPLES=200_000_000 # Decay for the first 410B tokens then continue at fixed --min-lr
LR_WARMUP_SAMPLES=183_105 # 375M tokens
OPTIMIZER_ARGS=" \
--optimizer adam \
--adam-beta1 0.9 \
--adam-beta2 0.95 \
--adam-eps 1e-8 \
--lr 6e-5 \
--min-lr 6e-6 \
--lr-decay-style cosine \
--lr-decay-samples $LR_DECAY_SAMPLES \
--lr-warmup-samples $LR_WARMUP_SAMPLES \
--clip-grad 1.0 \
--weight-decay 1e-1 \
"
# for 20h 1190, for 100h 5990
# --exit-duration-in-mins 1190 \
EXIT_OPTS=" \
--exit-duration-in-mins 5990 \
"
# --tokenizer-type PretrainedFromHF \
# --tokenizer-name-or-path $TOKENIZER_NAME_OR_PATH \
# --rampup-batch-size 192 16 9_765_625 \
GPT_ARGS=" \
--pp-partition-method 'type:transformer|embedding' \
--num-layers $NLAYERS \
--hidden-size $NHIDDEN \
--num-attention-heads $NHEADS \
--seq-length $SEQ_LEN \
--max-position-embeddings $SEQ_LEN \
--micro-batch-size $MICRO_BATCH_SIZE \
--global-batch-size $GLOBAL_BATCH_SIZE \
--train-samples $TRAIN_SAMPLES \
--vocab-file $VOCAB_FILE \
--merge-file $MERGE_FILE \
--tokenizer-type GPT2BPETokenizer \
--init-method-std 0.0048 \
--embed-layernorm \
--sync-tp-duplicated-parameters \
--bf16 \
--seed 42 \
--position-embedding-type alibi \
--checkpoint-activations \
--abort-on-unmet-fused-kernel-constraints \
--kill-switch-path $KILL_SWITCH_PATH \
--pad-vocab-size-to 250880 \
$OPTIMIZER_ARGS \
$EXIT_OPTS \
"
# TODO: decide on efficient eval-interval + eval-iters
OUTPUT_ARGS=" \
--log-interval 1 \
--save-interval $SAVE_INTERVAL \
--eval-interval 1000 \
--eval-iters 1 \
--tensorboard-dir $TENSORBOARD_PATH \
--tensorboard-queue-size 5 \
--log-timers-to-tensorboard \
--log-batch-size-to-tensorboard \
--log-validation-ppl-to-tensorboard \
"
ZERO_STAGE=0 # important: bf16 must use z0! it implements its own zero stage 1 equivalent
config_json="./ds_config.$SLURM_JOB_ID.json"
# Deepspeed figures out GAS dynamically from dynamic GBS via set_train_batch_size()
cat <<EOT > "$config_json"
{
"train_micro_batch_size_per_gpu": $MICRO_BATCH_SIZE,
"train_batch_size": $GLOBAL_BATCH_SIZE,
"gradient_clipping": 1.0,
"zero_optimization": {
"stage": $ZERO_STAGE
},
"bf16": {
"enabled": true
},
"steps_per_print": 2000,
"wall_clock_breakdown": false
}
EOT
DEEPSPEED_ARGS=" \
--deepspeed \
--deepspeed_config ${config_json} \
--zero-stage ${ZERO_STAGE} \
--deepspeed-activation-checkpointing \
"
# export LAUNCHER="python -u -m torch.distributed.run \
export LAUNCHER="python -u -m fixed_torch_run \
--nproc_per_node $GPUS_PER_NODE \
--nnodes $NNODES \
--rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \
--rdzv_backend c10d \
--rdzv_conf=is_host=\$(if ((SLURM_NODEID)); then echo 0; else echo 1; fi) \
--max_restarts 0 \
--tee 3 \
"
# --universal-checkpoint \
export CMD=" \
$(pwd)/pretrain_gpt.py \
--tensor-model-parallel-size $TP_SIZE \
--pipeline-model-parallel-size $PP_SIZE \
$GPT_ARGS \
$OUTPUT_ARGS \
--save $CHECKPOINT_PATH \
--data-path $DATA_PATH \
--split 949,50,1 \
--num-workers 2 \
--valid-num-workers 0 \
--data-impl mmap \
--distributed-backend nccl \
$DEEPSPEED_ARGS \
"
if [ "$LOAD_CHECKPOINTS" = true ] ; then
export CMD="$CMD\
--load $CHECKPOINT_PATH \
"
fi
echo $CMD
# do not remove or the training will hang and nodes will be lost w/o this workaround
export CUDA_LAUNCH_BLOCKING=1
# hide duplicated errors using this hack - will be properly fixed in pt-1.12
export TORCHELASTIC_ERROR_FILE=/tmp/torch-elastic-error.json
# force crashing on nccl issues like hanging broadcast
export NCCL_ASYNC_ERROR_HANDLING=1
# handle timeouts
export NCCL_IB_TIMEOUT=20
# srun error handling:
# --wait=60: wait 60 sec after the first task terminates before terminating all remaining tasks
# --kill-on-bad-exit=1: terminate a step if any task exits with a non-zero exit code
SRUN_ARGS=" \
--wait=60 \
--kill-on-bad-exit=1 \
"
clear; srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$LAUNCHER --node_rank \$SLURM_PROCID $CMD" 2>&1 | tee -a "$LOGS_PATH"/main_log.txt
echo "END TIME: $(date)"
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment