diff --git a/fixed_torch_run.py b/fixed_torch_run.py
new file mode 100644
index 0000000000000000000000000000000000000000..0dfcf9da58ee33d74ddd70fc52df549fa1ab0c75
--- /dev/null
+++ b/fixed_torch_run.py
@@ -0,0 +1,50 @@
+from argparse import ArgumentParser
+import ipaddress
+import runpy
+import socket
+
+from torch.distributed.elastic.agent.server import api as sapi
+
+
+def parse_host():
+    parser = ArgumentParser()
+    parser.add_argument('--rdzv_endpoint')
+    endpoint = parser.parse_known_args()[0].rdzv_endpoint
+    host = (
+        endpoint.split(':', 1)[0]
+        if endpoint
+        else None
+    )
+    return host
+
+
+def fix_torch_run(host):
+    _orig_get_fq_hostname = sapi._get_fq_hostname
+
+    if host:
+        try:
+            ipaddress.ip_address(host)
+            is_ip = True
+        except ValueError:
+            is_ip = False
+
+        if is_ip:
+            def new_get_fq_hostname():
+                return socket.gethostbyaddr(host)[0]
+        else:
+            def new_get_fq_hostname():
+                return socket.getfqdn(host)
+    else:
+        new_get_fq_hostname = _orig_get_fq_hostname
+
+    sapi._get_fq_hostname = new_get_fq_hostname
+
+
+def main():
+    host = parse_host()
+    fix_torch_run(host)
+    runpy.run_module('torch.distributed.run', run_name='__main__')
+
+
+if __name__ == '__main__':
+    main()
diff --git a/run_scripts/set_up.bash b/run_scripts/set_up.bash
index fdb7f7ca963afda7691c68584444475c02b4607f..d7454ce3899ae37272189717f12764cf9bf05c0c 100644
--- a/run_scripts/set_up.bash
+++ b/run_scripts/set_up.bash
@@ -43,6 +43,10 @@ python -m pip install --upgrade pip
 
 cd "$MEGATRON_DEEPSPEED_REPO"
 ((DO_PULL)) && git stash && git pull --rebase origin main && git stash pop || :
+
+# Link repaired start script to code repo.
+ln -sf "$OLDPWD"/../fixed_torch_run.py ./fixed_torch_run.py
+
 git am "$OLDPWD"/../0001-Build-fused-kernels-in-temporary-directory.patch \
     || (
         git am --abort;
diff --git a/run_scripts/tr11-176B-ml_juwels_pipe.sbatch b/run_scripts/tr11-176B-ml_juwels_pipe.sbatch
new file mode 100644
index 0000000000000000000000000000000000000000..16216ae342b30d408198e6aa2041d5a11a38b9e2
--- /dev/null
+++ b/run_scripts/tr11-176B-ml_juwels_pipe.sbatch
@@ -0,0 +1,249 @@
+#!/bin/bash
+#SBATCH --job-name=tr11-176B-ml
+#SBATCH --nodes=96
+#SBATCH --ntasks-per-node=1          # crucial - only 1 task per dist per node!
+#SBATCH --cpus-per-task=48           # number of cores per tasks
+#SBATCH --hint=nomultithread         # we get physical cores not logical
+#SBATCH --gres=gpu:4                 # number of gpus
+#SBATCH --time 00:10:00              # maximum execution time (HH:MM:SS)
+#SBATCH --output=%x-%j.out           # output file name
+#SBATCH --account=opengptx-elm
+# Use `develbooster` for debugging, `booster` for "normal" jobs, and
+# `largebooster` for jobs on more than 256 nodes.
+#SBATCH --partition=booster
+
+set -x -e
+
+echo "START TIME: $(date)"
+
+CLEAN_PREV_JIT_BUILD=0
+
+if ! [ -e activate.bash ]; then
+    echo 'Please execute the sbatch script from the `run_scripts` directory.'
+    exit 1
+fi
+source activate.bash || exit 1
+
+variant=main
+
+# The following paths might already be set in long-running-session in StartLongRun.sh
+[ "x$DATA_OUTPUT_PATH" = x ] && DATA_OUTPUT_PATH="$ROOT_OUTPUT_DIR"/output_dir/tr11-176B-ml/"$variant"
+[ "x$CHECKPOINT_PATH" = x ] && CHECKPOINT_PATH=$DATA_OUTPUT_PATH/checkpoints
+[ "x$TENSORBOARD_PATH" = x ] && TENSORBOARD_PATH=$DATA_OUTPUT_PATH/tensorboard
+[ "x$CODECARBON_PATH" = x ] && CODECARBON_PATH=$DATA_OUTPUT_PATH/codecarbon
+[ "x$LOGS_PATH" = x ] && LOGS_PATH=$DATA_OUTPUT_PATH/logs
+mkdir -p "$LOGS_PATH"
+
+cd "$MEGATRON_DEEPSPEED_REPO"
+rm -f megatron/fused_kernels/build/lock
+((CLEAN_PREV_JIT_BUILD)) && rm -rf megatron/fused_kernels/{build,__pycache__}
+
+KILL_SWITCH_PATH="$MEGATRON_DEEPSPEED_REPO"/kill-switch-tr11-176B-exp1
+
+# BIGSCIENCE_REPO=$six_ALL_CCFRWORK/code/tr11-176B-ml/bigscience
+# TRAIN_DATA_PATH=$MEGATRON_DEEPSPEED_REPO/data/train-splits.txt
+# VALID_DATA_PATH=$MEGATRON_DEEPSPEED_REPO/data/valid-splits.txt
+# CATALOGUE_JSON_PATH=$BIGSCIENCE_REPO/data/catalogue/training_dataset_ratios_merged_nigercongo_v3.json
+# LOAD_RATIOS_SCRIPT=$BIGSCIENCE_REPO/data/catalogue/load_ratios_meg_ds_format.py
+# python $LOAD_RATIOS_SCRIPT --dataset-ratios-path $CATALOGUE_JSON_PATH --split train --output-meg-ds-ratio-file $TRAIN_DATA_PATH
+# python $LOAD_RATIOS_SCRIPT --dataset-ratios-path $CATALOGUE_JSON_PATH --split valid --output-meg-ds-ratio-file $VALID_DATA_PATH
+
+# TOKENIZER_NAME_OR_PATH=bigscience-catalogue-data-dev/byte-level-bpe-tokenizer-no-norm-250k-whitespace-and-eos-regex-alpha-v3-dedup-lines-articles
+
+# defining the right environment variables
+CACHE_DIR="$ROOT_OUTPUT_DIR/.cache"
+mkdir -p "$CACHE_DIR"
+export TRANSFORMERS_CACHE="$CACHE_DIR"/models
+export HF_DATASETS_CACHE="$CACHE_DIR"/datasets
+export HF_MODULES_CACHE="$CACHE_DIR"/modules
+export HF_METRICS_CACHE="$CACHE_DIR"/metrics
+export HF_DATASETS_OFFLINE=1
+export TRANSFORMERS_OFFLINE=1
+
+# testing for potential faulty nodes
+# srun --jobid $SLURM_JOB_ID bash -c 'python -c "import torch, socket; print(socket.gethostname(), torch.cuda.is_available())"'
+
+# so processes know who to talk to
+MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)"
+# Allow communication over InfiniBand cells.
+MASTER_ADDR="${MASTER_ADDR}i"
+# Get IP for hostname.
+MASTER_ADDR="$(nslookup "$MASTER_ADDR" | grep -oP '(?<=Address: ).*')"
+MASTER_PORT=6000
+
+GPUS_PER_NODE=4
+NNODES=$SLURM_NNODES
+
+# Could also use half as many nodes (48), half PP_SIZE (12), but
+# double TP_SIZE (8).
+TP_SIZE=4
+PP_SIZE=24
+
+MICRO_BATCH_SIZE=2  # was MBS=1 till GBS=784
+# GLOBAL_BATCH_SIZE=2048  # 4.2M tokens. It is larger than the initial plan of 3.2M tokens to get higher throughput
+GLOBAL_BATCH_SIZE=$(((NNODES * GPUS_PER_NODE / (PP_SIZE * TP_SIZE)) * MICRO_BATCH_SIZE))
+
+NHIDDEN=14336
+NLAYERS=70
+NHEADS=112
+SEQ_LEN=2048
+
+SAVE_INTERVAL=100
+
+TRAIN_SAMPLES=220_000_000  # 450B tokens
+LR_DECAY_SAMPLES=200_000_000  # Decay for the first 410B tokens then continue at fixed --min-lr
+LR_WARMUP_SAMPLES=183_105  # 375M tokens
+
+
+OPTIMIZER_ARGS=" \
+    --optimizer adam \
+    --adam-beta1 0.9 \
+    --adam-beta2 0.95 \
+    --adam-eps 1e-8 \
+    --lr 6e-5 \
+    --min-lr 6e-6 \
+    --lr-decay-style cosine \
+    --lr-decay-samples $LR_DECAY_SAMPLES \
+    --lr-warmup-samples $LR_WARMUP_SAMPLES \
+    --clip-grad 1.0 \
+    --weight-decay 1e-1 \
+    "
+# for 20h 1190, for 100h 5990
+# --exit-duration-in-mins 1190 \
+EXIT_OPTS=" \
+    --exit-duration-in-mins 5990 \
+    "
+
+    # --tokenizer-type PretrainedFromHF \
+    # --tokenizer-name-or-path $TOKENIZER_NAME_OR_PATH \
+
+    # --rampup-batch-size 192 16 9_765_625 \
+GPT_ARGS=" \
+    --pp-partition-method 'type:transformer|embedding' \
+    --num-layers $NLAYERS \
+    --hidden-size $NHIDDEN \
+    --num-attention-heads $NHEADS \
+    --seq-length $SEQ_LEN \
+    --max-position-embeddings $SEQ_LEN \
+    --micro-batch-size $MICRO_BATCH_SIZE \
+    --global-batch-size $GLOBAL_BATCH_SIZE \
+    --train-samples $TRAIN_SAMPLES \
+    --vocab-file $VOCAB_FILE \
+    --merge-file $MERGE_FILE \
+    --tokenizer-type GPT2BPETokenizer \
+    --init-method-std 0.0048 \
+    --embed-layernorm \
+    --sync-tp-duplicated-parameters \
+    --bf16 \
+    --seed 42 \
+    --position-embedding-type alibi \
+    --checkpoint-activations \
+    --abort-on-unmet-fused-kernel-constraints \
+    --kill-switch-path $KILL_SWITCH_PATH \
+    --pad-vocab-size-to 250880 \
+    $OPTIMIZER_ARGS \
+    $EXIT_OPTS \
+    "
+
+# TODO: decide on efficient eval-interval + eval-iters
+
+OUTPUT_ARGS=" \
+    --log-interval 1 \
+    --save-interval $SAVE_INTERVAL \
+    --eval-interval 1000 \
+    --eval-iters 1 \
+    --tensorboard-dir $TENSORBOARD_PATH \
+    --tensorboard-queue-size 5 \
+    --log-timers-to-tensorboard \
+    --log-batch-size-to-tensorboard \
+    --log-validation-ppl-to-tensorboard \
+    "
+
+ZERO_STAGE=0 # important: bf16 must use z0! it implements its own zero stage 1 equivalent
+
+config_json="./ds_config.$SLURM_JOB_ID.json"
+
+# Deepspeed figures out GAS dynamically from dynamic GBS via set_train_batch_size()
+cat <<EOT > "$config_json"
+{
+  "train_micro_batch_size_per_gpu": $MICRO_BATCH_SIZE,
+  "train_batch_size": $GLOBAL_BATCH_SIZE,
+  "gradient_clipping": 1.0,
+  "zero_optimization": {
+    "stage": $ZERO_STAGE
+  },
+  "bf16": {
+    "enabled": true
+  },
+  "steps_per_print": 2000,
+  "wall_clock_breakdown": false
+}
+EOT
+
+
+DEEPSPEED_ARGS=" \
+    --deepspeed \
+    --deepspeed_config ${config_json} \
+    --zero-stage ${ZERO_STAGE} \
+    --deepspeed-activation-checkpointing \
+    "
+
+# export LAUNCHER="python -u -m torch.distributed.run \
+export LAUNCHER="python -u -m fixed_torch_run \
+    --nproc_per_node $GPUS_PER_NODE \
+    --nnodes $NNODES \
+    --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \
+    --rdzv_backend c10d \
+    --rdzv_conf=is_host=\$(if ((SLURM_NODEID)); then echo 0; else echo 1; fi) \
+    --max_restarts 0 \
+    --tee 3 \
+    "
+
+# --universal-checkpoint \
+export CMD=" \
+    $(pwd)/pretrain_gpt.py \
+    --tensor-model-parallel-size $TP_SIZE \
+    --pipeline-model-parallel-size $PP_SIZE \
+    $GPT_ARGS \
+    $OUTPUT_ARGS \
+    --save $CHECKPOINT_PATH \
+    --data-path $DATA_PATH \
+    --split 949,50,1 \
+    --num-workers 2 \
+    --valid-num-workers 0 \
+    --data-impl mmap \
+    --distributed-backend nccl \
+    $DEEPSPEED_ARGS \
+    "
+
+if [ "$LOAD_CHECKPOINTS" = true ] ; then
+    export CMD="$CMD\
+        --load $CHECKPOINT_PATH \
+        "
+fi
+
+echo $CMD
+
+# do not remove or the training will hang and nodes will be lost w/o this workaround
+export CUDA_LAUNCH_BLOCKING=1
+
+# hide duplicated errors using this hack - will be properly fixed in pt-1.12
+export TORCHELASTIC_ERROR_FILE=/tmp/torch-elastic-error.json
+
+# force crashing on nccl issues like hanging broadcast
+export NCCL_ASYNC_ERROR_HANDLING=1
+
+# handle timeouts
+export NCCL_IB_TIMEOUT=20
+
+# srun error handling:
+# --wait=60: wait 60 sec after the first task terminates before terminating all remaining tasks
+# --kill-on-bad-exit=1: terminate a step if any task exits with a non-zero exit code
+SRUN_ARGS=" \
+    --wait=60 \
+    --kill-on-bad-exit=1 \
+    "
+
+clear; srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$LAUNCHER --node_rank \$SLURM_PROCID $CMD" 2>&1 | tee -a "$LOGS_PATH"/main_log.txt
+
+echo "END TIME: $(date)"
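
For reference (not part of the patch): fixed_torch_run.py works by replacing
torch.distributed.elastic.agent.server.api._get_fq_hostname before delegating to
torch.distributed.run, so the elastic agent advertises the host taken from --rdzv_endpoint
(the InfiniBand-reachable name resolved above) instead of the default local hostname. A minimal
sketch of how the patched resolution could be exercised on a login node; it assumes a PyTorch
version that still defines _get_fq_hostname in that module, and the file name
check_fixed_torch_run.py is hypothetical.

    # check_fixed_torch_run.py -- hypothetical sanity check, not part of the patch.
    import sys

    from torch.distributed.elastic.agent.server import api as sapi

    from fixed_torch_run import fix_torch_run

    if __name__ == '__main__':
        # Pass either an IP (needs a working reverse DNS entry) or a hostname;
        # with no argument the original resolution is kept unchanged.
        host = sys.argv[1] if len(sys.argv) > 1 else None
        fix_torch_run(host)
        # Prints the name the elastic agent would now report during rendezvous.
        print(sapi._get_fq_hostname())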
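
Also for reference (not part of the patch): GLOBAL_BATCH_SIZE is computed from the parallel
layout rather than using the commented-out fixed value of 2048. A worked example of that
arithmetic with the values set above (96 nodes, 4 GPUs per node, TP=4, PP=24, MBS=2):

    # Worked example of the GLOBAL_BATCH_SIZE line in the sbatch script above.
    nnodes, gpus_per_node = 96, 4
    tp_size, pp_size = 4, 24
    micro_batch_size = 2

    world_size = nnodes * gpus_per_node          # 384 GPUs
    dp_size = world_size // (tp_size * pp_size)  # 4 data-parallel replicas
    global_batch_size = dp_size * micro_batch_size
    print(world_size, dp_size, global_batch_size)  # 384 4 8

So each step processes 8 samples, and DeepSpeed's set_train_batch_size() should derive a gradient
accumulation count of GBS / (MBS * DP) = 1, i.e. one micro-batch per replica per step.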