diff --git a/fixed_torch_run.py b/fixed_torch_run.py
new file mode 100644
index 0000000000000000000000000000000000000000..0dfcf9da58ee33d74ddd70fc52df549fa1ab0c75
--- /dev/null
+++ b/fixed_torch_run.py
@@ -0,0 +1,50 @@
+from argparse import ArgumentParser
+import ipaddress
+import runpy
+import socket
+
+from torch.distributed.elastic.agent.server import api as sapi
+
+
+def parse_host():
+    parser = ArgumentParser()
+    parser.add_argument('--rdzv_endpoint')
+    endpoint = parser.parse_known_args()[0].rdzv_endpoint
+    host = (
+        endpoint.split(':', 1)[0]
+        if endpoint
+        else None
+    )
+    return host
+
+
+def fix_torch_run(host):
+    _orig_get_fq_hostname = sapi._get_fq_hostname
+
+    if host:
+        try:
+            ipaddress.ip_address(host)
+            is_ip = True
+        except ValueError:
+            is_ip = False
+
+        if is_ip:
+            def new_get_fq_hostname():
+                return socket.gethostbyaddr(host)[0]
+        else:
+            def new_get_fq_hostname():
+                return socket.getfqdn(host)
+    else:
+        new_get_fq_hostname = _orig_get_fq_hostname
+
+    sapi._get_fq_hostname = new_get_fq_hostname
+
+
+def main():
+    host = parse_host()
+    fix_torch_run(host)
+    runpy.run_module('torch.distributed.run', run_name='__main__')
+
+
+if __name__ == '__main__':
+    main()
diff --git a/run_scripts/set_up.bash b/run_scripts/set_up.bash
index fdb7f7ca963afda7691c68584444475c02b4607f..d7454ce3899ae37272189717f12764cf9bf05c0c 100644
--- a/run_scripts/set_up.bash
+++ b/run_scripts/set_up.bash
@@ -43,6 +43,10 @@ python -m pip install --upgrade pip
 
 cd "$MEGATRON_DEEPSPEED_REPO"
 ((DO_PULL)) && git stash && git pull --rebase origin main && git stash pop || :
+
+# Link repaired start script to code repo.
+ln -sf "$OLDPWD"/../fixed_torch_run.py ./fixed_torch_run.py
+
 git am "$OLDPWD"/../0001-Build-fused-kernels-in-temporary-directory.patch \
     || (
     git am --abort;
diff --git a/run_scripts/tr11-176B-ml_juwels_pipe.sbatch b/run_scripts/tr11-176B-ml_juwels_pipe.sbatch
new file mode 100644
index 0000000000000000000000000000000000000000..16216ae342b30d408198e6aa2041d5a11a38b9e2
--- /dev/null
+++ b/run_scripts/tr11-176B-ml_juwels_pipe.sbatch
@@ -0,0 +1,249 @@
+#!/bin/bash
+#SBATCH --job-name=tr11-176B-ml
+#SBATCH --nodes=96
+#SBATCH --ntasks-per-node=1          # crucial - only 1 task per dist per node!
+#SBATCH --cpus-per-task=48           # number of cores per tasks
+#SBATCH --hint=nomultithread         # we get physical cores not logical
+#SBATCH --gres=gpu:4                 # number of gpus
+#SBATCH --time 00:10:00             # maximum execution time (HH:MM:SS)
+#SBATCH --output=%x-%j.out           # output file name
+#SBATCH --account=opengptx-elm
+# Use `develbooster` for debugging, `booster` for "normal" jobs, and
+# `largebooster` for jobs on more than 256 nodes.
+#SBATCH --partition=booster
+
+set -x -e
+
+echo "START TIME: $(date)"
+
+CLEAN_PREV_JIT_BUILD=0
+
+if ! [ -e activate.bash ]; then
+    echo 'Please execute the sbatch script from the `run_scripts` directory.'
+    exit 1
+fi
+source activate.bash || exit 1
+
+variant=main
+
+# The following paths might already be set in long-running-session in  StartLongRun.sh
+[ "x$DATA_OUTPUT_PATH" = x ] &&  DATA_OUTPUT_PATH="$ROOT_OUTPUT_DIR"/output_dir/tr11-176B-ml/"$variant"
+[ "x$CHECKPOINT_PATH" = x ] && CHECKPOINT_PATH=$DATA_OUTPUT_PATH/checkpoints
+[ "x$TENSORBOARD_PATH" = x ] &&  TENSORBOARD_PATH=$DATA_OUTPUT_PATH/tensorboard
+[ "x$CODECARBON_PATH" = x ] &&  CODECARBON_PATH=$DATA_OUTPUT_PATH/codecarbon
+[ "x$LOGS_PATH" = x ] &&  LOGS_PATH=$DATA_OUTPUT_PATH/logs
+mkdir -p "$LOGS_PATH"
+
+cd "$MEGATRON_DEEPSPEED_REPO"
+rm -f megatron/fused_kernels/build/lock
+((CLEAN_PREV_JIT_BUILD)) && rm -rf megatron/fused_kernels/{build,__pycache__}
+
+KILL_SWITCH_PATH="$MEGATRON_DEEPSPEED_REPO"/kill-switch-tr11-176B-exp1
+
+# BIGSCIENCE_REPO=$six_ALL_CCFRWORK/code/tr11-176B-ml/bigscience
+# TRAIN_DATA_PATH=$MEGATRON_DEEPSPEED_REPO/data/train-splits.txt
+# VALID_DATA_PATH=$MEGATRON_DEEPSPEED_REPO/data/valid-splits.txt
+# CATALOGUE_JSON_PATH=$BIGSCIENCE_REPO/data/catalogue/training_dataset_ratios_merged_nigercongo_v3.json
+# LOAD_RATIOS_SCRIPT=$BIGSCIENCE_REPO/data/catalogue/load_ratios_meg_ds_format.py
+# python $LOAD_RATIOS_SCRIPT --dataset-ratios-path $CATALOGUE_JSON_PATH --split train --output-meg-ds-ratio-file $TRAIN_DATA_PATH
+# python $LOAD_RATIOS_SCRIPT --dataset-ratios-path $CATALOGUE_JSON_PATH --split valid --output-meg-ds-ratio-file $VALID_DATA_PATH
+
+# TOKENIZER_NAME_OR_PATH=bigscience-catalogue-data-dev/byte-level-bpe-tokenizer-no-norm-250k-whitespace-and-eos-regex-alpha-v3-dedup-lines-articles
+
+# defining the right environment variables
+CACHE_DIR="$ROOT_OUTPUT_DIR/.cache"
+mkdir -p "$CACHE_DIR"
+export TRANSFORMERS_CACHE="$CACHE_DIR"/models
+export HF_DATASETS_CACHE="$CACHE_DIR"/datasets
+export HF_MODULES_CACHE="$CACHE_DIR"/modules
+export HF_METRICS_CACHE="$CACHE_DIR"/metrics
+export HF_DATASETS_OFFLINE=1
+export TRANSFORMERS_OFFLINE=1
+
+# testing for potential faulty nodes
+# srun --jobid $SLURM_JOB_ID bash -c 'python -c "import torch, socket; print(socket.gethostname(), torch.cuda.is_available())"'
+
+# so processes know who to talk to
+MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)"
+# Allow communication over InfiniBand cells.
+MASTER_ADDR="${MASTER_ADDR}i"
+# Get IP for hostname.
+MASTER_ADDR="$(nslookup "$MASTER_ADDR" | grep -oP '(?<=Address: ).*')"
+MASTER_PORT=6000
+
+GPUS_PER_NODE=4
+NNODES=$SLURM_NNODES
+
+# Could also use half as many nodes (48), half PP_SIZE (12), but
+# double TP_SIZE (8).
+TP_SIZE=4
+PP_SIZE=24
+
+MICRO_BATCH_SIZE=2  # was MBS=1 till GBS=784
+# GLOBAL_BATCH_SIZE=2048  # 4.2M tokens. It is larger than the initial plan of 3.2M tokens to get higher throughput
+GLOBAL_BATCH_SIZE=$(((NNODES * GPUS_PER_NODE / (PP_SIZE * TP_SIZE)) * MICRO_BATCH_SIZE))
+
+NHIDDEN=14336
+NLAYERS=70
+NHEADS=112
+SEQ_LEN=2048
+
+SAVE_INTERVAL=100
+
+TRAIN_SAMPLES=220_000_000  # 450B tokens
+LR_DECAY_SAMPLES=200_000_000  # Decay for the first 410B tokens then continue at fixed --min-lr
+LR_WARMUP_SAMPLES=183_105  # 375M tokens
+
+
+OPTIMIZER_ARGS=" \
+    --optimizer adam \
+    --adam-beta1 0.9 \
+    --adam-beta2 0.95 \
+    --adam-eps 1e-8 \
+    --lr 6e-5 \
+    --min-lr 6e-6 \
+    --lr-decay-style cosine \
+    --lr-decay-samples $LR_DECAY_SAMPLES \
+    --lr-warmup-samples $LR_WARMUP_SAMPLES \
+    --clip-grad 1.0 \
+    --weight-decay 1e-1 \
+    "
+# for 20h 1190, for 100h 5990
+#    --exit-duration-in-mins 1190 \
+EXIT_OPTS=" \
+    --exit-duration-in-mins 5990 \
+    "
+
+    # --tokenizer-type PretrainedFromHF \
+    # --tokenizer-name-or-path $TOKENIZER_NAME_OR_PATH \
+
+    # --rampup-batch-size 192 16 9_765_625 \
+GPT_ARGS=" \
+    --pp-partition-method 'type:transformer|embedding' \
+    --num-layers $NLAYERS \
+    --hidden-size $NHIDDEN \
+    --num-attention-heads $NHEADS \
+    --seq-length $SEQ_LEN \
+    --max-position-embeddings $SEQ_LEN \
+    --micro-batch-size $MICRO_BATCH_SIZE \
+    --global-batch-size $GLOBAL_BATCH_SIZE \
+    --train-samples $TRAIN_SAMPLES \
+    --vocab-file $VOCAB_FILE \
+    --merge-file $MERGE_FILE \
+    --tokenizer-type GPT2BPETokenizer \
+    --init-method-std 0.0048 \
+    --embed-layernorm \
+    --sync-tp-duplicated-parameters \
+    --bf16 \
+    --seed 42 \
+    --position-embedding-type alibi \
+    --checkpoint-activations \
+    --abort-on-unmet-fused-kernel-constraints \
+    --kill-switch-path $KILL_SWITCH_PATH \
+    --pad-vocab-size-to 250880 \
+    $OPTIMIZER_ARGS \
+    $EXIT_OPTS \
+    "
+
+# TODO: decide on efficient eval-interval + eval-iters
+
+OUTPUT_ARGS=" \
+    --log-interval 1 \
+    --save-interval $SAVE_INTERVAL \
+    --eval-interval 1000 \
+    --eval-iters 1 \
+    --tensorboard-dir $TENSORBOARD_PATH \
+    --tensorboard-queue-size 5 \
+    --log-timers-to-tensorboard \
+    --log-batch-size-to-tensorboard \
+    --log-validation-ppl-to-tensorboard \
+    "
+
+ZERO_STAGE=0 # important: bf16 must use z0! it implements its own zero stage 1 equivalent
+
+config_json="./ds_config.$SLURM_JOB_ID.json"
+
+# Deepspeed figures out GAS dynamically from dynamic GBS via set_train_batch_size()
+cat <<EOT > "$config_json"
+{
+  "train_micro_batch_size_per_gpu": $MICRO_BATCH_SIZE,
+  "train_batch_size": $GLOBAL_BATCH_SIZE,
+  "gradient_clipping": 1.0,
+  "zero_optimization": {
+    "stage": $ZERO_STAGE
+  },
+  "bf16": {
+    "enabled": true
+  },
+  "steps_per_print": 2000,
+  "wall_clock_breakdown": false
+}
+EOT
+
+
+DEEPSPEED_ARGS=" \
+    --deepspeed \
+    --deepspeed_config ${config_json} \
+    --zero-stage ${ZERO_STAGE} \
+    --deepspeed-activation-checkpointing \
+    "
+
+# export LAUNCHER="python -u -m torch.distributed.run \
+export LAUNCHER="python -u -m fixed_torch_run \
+    --nproc_per_node $GPUS_PER_NODE \
+    --nnodes $NNODES \
+    --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \
+    --rdzv_backend c10d \
+    --rdzv_conf=is_host=\$(if ((SLURM_NODEID)); then echo 0; else echo 1; fi) \
+    --max_restarts 0 \
+    --tee 3 \
+    "
+
+#    --universal-checkpoint \
+export CMD=" \
+    $(pwd)/pretrain_gpt.py \
+    --tensor-model-parallel-size $TP_SIZE \
+    --pipeline-model-parallel-size $PP_SIZE \
+    $GPT_ARGS \
+    $OUTPUT_ARGS \
+    --save $CHECKPOINT_PATH \
+    --data-path $DATA_PATH \
+    --split 949,50,1 \
+    --num-workers 2 \
+    --valid-num-workers 0 \
+    --data-impl mmap \
+    --distributed-backend nccl \
+     $DEEPSPEED_ARGS \
+    "
+
+if [ "$LOAD_CHECKPOINTS" = true ] ; then
+    export CMD="$CMD\
+        --load $CHECKPOINT_PATH \
+        "
+fi
+
+echo $CMD
+
+# do not remove or the training will hang and nodes will be lost w/o this workaround
+export CUDA_LAUNCH_BLOCKING=1
+
+# hide duplicated errors using this hack - will be properly fixed in pt-1.12
+export TORCHELASTIC_ERROR_FILE=/tmp/torch-elastic-error.json
+
+# force crashing on nccl issues like hanging broadcast
+export NCCL_ASYNC_ERROR_HANDLING=1
+
+# handle timeouts
+export NCCL_IB_TIMEOUT=20
+
+# srun error handling:
+# --wait=60: wait 60 sec after the first task terminates before terminating all remaining tasks
+# --kill-on-bad-exit=1: terminate a step if any task exits with a non-zero exit code
+SRUN_ARGS=" \
+    --wait=60 \
+    --kill-on-bad-exit=1 \
+    "
+
+clear; srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$LAUNCHER --node_rank \$SLURM_PROCID $CMD" 2>&1 | tee -a "$LOGS_PATH"/main_log.txt
+
+echo "END TIME: $(date)"