Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .github/scripts/ci_sgpu_jobs.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# sGPU job definitions for CI test suites.
# Format: <label> <logfile> <command> [args...]
# GPUs are assigned sequentially by run_parallel_sgpu.sh starting from --first-gpu.

examples examples.log .github/scripts/run_examples.sh
torch torch.log ci/pytorch.sh
jax jax.log ci/jax.sh
core core.log ci/core.sh
34 changes: 34 additions & 0 deletions .github/scripts/run_examples.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#!/usr/bin/bash
# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
#
# See LICENSE for license information.
#
# Run TE examples on a single GPU.
# HIP_VISIBLE_DEVICES must be set by the caller (run_parallel_sgpu.sh).

set -e

# Autodetect repo root from this script's location (.github/scripts/ -> ../..)
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
REPO_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)"

python -c "import os; print('HF_TOKEN set:', bool(os.environ.get('HF_TOKEN')))"

JAX_CONSTRAINTS=/tmp/jax-constraints.txt
pip freeze | grep -iE '^(jax|jaxlib|jax[_-]rocm|jax[_-]plugins)[=@]' > "$JAX_CONSTRAINTS" || true

cd "${REPO_ROOT}/examples/pytorch/mnist"
python main.py
python main.py --use-te
python main.py --use-fp8

cd "${REPO_ROOT}/examples/jax/mnist"
pip3 install -c "$JAX_CONSTRAINTS" -r requirements.txt
python test_single_gpu_mnist.py
python test_single_gpu_mnist.py --use-te
python test_single_gpu_mnist.py --use-fp8

cd "${REPO_ROOT}/examples/jax/encoder"
pip3 install -c "$JAX_CONSTRAINTS" -r requirements.txt
python test_single_gpu_encoder.py
python test_single_gpu_encoder.py --use-fp8
186 changes: 186 additions & 0 deletions .github/scripts/run_parallel_sgpu.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
#!/bin/bash
# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
#
# See LICENSE for license information.
#
# Run multiple sGPU jobs in parallel from one or more config files.
#
# Usage: run_parallel_sgpu.sh [-g|--first-gpu <n>] [-l|--log-dir <dir>] <config>...
#
# Config format (one job per line; # comments and blank lines are ignored):
# <label> <logfile> <command> [args...]
#
# GPUs are assigned sequentially starting from --first-gpu across all config
# entries in the order configs are passed on the command line.
#
# Each job's exit code is written to <log-dir>/<logfile>.rc by the job runner.

# Resolve repo root relative to this script's location (.github/scripts/ -> ../..)
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
REPO_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)"

# How often (seconds) to poll child processes
: "${POLL_INTERVAL:=5}"
# Warn if a log file has not been updated in this many seconds
: "${STALL_WARN_SECS:=180}"
# Number of tail lines to show for stalled logs
: "${STALL_TAIL_LINES:=3}"
# Number of new lines to show for resumed logs
: "${STALL_RESUME_CONTEXT_LINES:=2}"

# Associative arrays: _JOB_PIDS[name]=pid _JOB_LOGS[pid]=logfile
declare -A _JOB_PIDS
declare -A _JOB_LOGS
_OVERALL_RC=0

# Launch a background job and register it.
# Usage: launch_job <name> <logfile> <cmd> [args...]
launch_job() {
local name="$1"
local logfile="$2"
local rcfile="${logfile}.rc"
shift 2
rm -f "$rcfile"
"$@" >"$logfile" 2>&1 &
local pid=$!
_JOB_PIDS["$name"]=$pid
_JOB_LOGS[$pid]="$logfile"
echo "Started '${name}' (pid ${pid}) -> ${logfile}"
}

# Wait for all currently registered jobs, polling every POLL_INTERVAL seconds.
# Writes <logfile>.rc for every finished job; warns about stalled logs.
# Clears _JOB_PIDS/_JOB_LOGS when done.
wait_for_jobs() {
local -A remaining
local -A stall_mtime
local -A stall_lineno
for name in "${!_JOB_PIDS[@]}"; do
remaining["$name"]="${_JOB_PIDS[$name]}"
done

while [ ${#remaining[@]} -gt 0 ]; do
sleep "$POLL_INTERVAL"

for name in "${!remaining[@]}"; do
local pid="${remaining[$name]}"

if ! kill -0 "$pid" 2>/dev/null; then
# Process has exited — capture its return code
wait "$pid"
local rc=$?
echo "[$(date '+%Y-%m-%d %H:%M:%S')] '${name}' (pid ${pid}) finished with rc=${rc}"
if [ $rc -ne 0 ]; then
_OVERALL_RC=$rc
fi
echo "$rc" > "${_JOB_LOGS[$pid]}.rc"
unset "remaining[$name]"
else
# Process still running — check for log staleness
local logfile="${_JOB_LOGS[$pid]}"
if [ -f "$logfile" ]; then
local now mtime age
now=$(date +%s)
mtime=$(stat -c '%Y' "$logfile" 2>/dev/null || echo "$now")
age=$(( now - mtime ))
if [ -n "${stall_mtime[$pid]+set}" ]; then
if [ "$mtime" -gt "${stall_mtime[$pid]}" ]; then
local frozen_secs=$(( mtime - stall_mtime[$pid] ))
local freeze_line="${stall_lineno[$pid]}"
echo "[$(date '+%Y-%m-%d %H:%M:%S')] INFO: '${name}' (pid ${pid}) log '${logfile}' resumed updating after ${frozen_secs}s"
echo "--- first ${STALL_RESUME_CONTEXT_LINES} lines of ${logfile} starting ${freeze_line} ---"
tail -n "+${freeze_line}" "$logfile" | head -n ${STALL_RESUME_CONTEXT_LINES}
echo "---"
unset "stall_mtime[$pid]"
unset "stall_lineno[$pid]"
fi
# else: still stalled but already warned — do nothing
elif [ "$age" -ge "$STALL_WARN_SECS" ]; then
# don't use wc here because it doces not count the last line if it doesn't end with a newline
local freeze_line=$(grep -c '' < "$logfile")
stall_mtime[$pid]=$mtime
stall_lineno[$pid]=$freeze_line
echo "[$(date '+%Y-%m-%d %H:%M:%S')] WARNING: '${name}' (pid ${pid}) log '${logfile}' has not been updated for ${age}s"
echo "--- last ${STALL_TAIL_LINES} lines of ${logfile} up to ${freeze_line} ---"
head -n "${freeze_line}" "$logfile" | tail -n "$STALL_TAIL_LINES"
echo "--- end of ${logfile} ---"
fi
fi
fi
done
done

# Reset for next batch
unset _JOB_PIDS
unset _JOB_LOGS
declare -gA _JOB_PIDS
declare -gA _JOB_LOGS
}

FIRST_GPU=${TEST_FIRST_GPU:-0}
LOG_DIR=${LOG_DIR:-/tmp/te_ci_logs}

# ---------------------------------------------------------------------------
# Parse arguments
while [[ $# -gt 0 ]]; do
case "$1" in
-g|--first-gpu)
FIRST_GPU="$2"; shift 2 ;;
--first-gpu=*)
FIRST_GPU="${1#*=}"; shift ;;
-l|--log-dir)
LOG_DIR="$2"; shift 2 ;;
--log-dir=*)
LOG_DIR="${1#*=}"; shift ;;
-*)
echo "Unknown option: $1" >&2
echo "Usage: $0 [-g|--first-gpu <n>] [-l|--log-dir <dir>] <config>..." >&2
exit 1 ;;
*)
break ;;
esac
done

if [[ $# -eq 0 ]]; then
echo "Error: at least one config file is required." >&2
echo "Usage: $0 [-g|--first-gpu <n>] [-l|--log-dir <dir>] <config>..." >&2
exit 1
fi

# Resolve config paths and LOG_DIR to absolute against the caller's CWD
# *before* changing directory, so relative paths passed by the caller remain valid.
resolved_configs=()
for c in "$@"; do
resolved_configs+=( "$(realpath -m "$c")" )
done
[[ "$LOG_DIR" != /* ]] && LOG_DIR="$(realpath -m "$LOG_DIR")"

mkdir -p "$LOG_DIR"

# cd to repo root so that commands in configs (e.g. ci/pytorch.sh) resolve correctly
cd "$REPO_ROOT" || { echo "Error: cannot cd to '${REPO_ROOT}'" >&2; exit 1; }

# ---------------------------------------------------------------------------
# Launch all jobs, assigning GPUs sequentially across all config files
gpu=$FIRST_GPU
for config in "${resolved_configs[@]}"; do
while IFS= read -r line || [[ -n "$line" ]]; do
# Skip blank lines and comments
[[ "$line" =~ ^[[:space:]]*# ]] && continue
[[ -z "${line//[[:space:]]/}" ]] && continue
read -r label logfile rest <<< "$line"
# Each suite appends its own subdir to the inherited base prefix so the base path is defined once.
if [ -n "${JUNITXML_PREFIX}${JUNITXML_SUFFIX}" ]; then
junitxml_dir="${JUNITXML_PREFIX}${label}/"
mkdir -p "${junitxml_dir}"
else
junitxml_dir=""
fi
# shellcheck disable=SC2086 # $rest is intentionally word-split
HIP_VISIBLE_DEVICES=$gpu JUNITXML_PREFIX="${junitxml_dir}" launch_job "$label" "${LOG_DIR}/$logfile" $rest
(( gpu++ ))
done < "$config"
done

wait_for_jobs
exit $_OVERALL_RC
127 changes: 22 additions & 105 deletions .github/workflows/rocm-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ jobs:
sgpu_tests:
name: sGPU Tests (${{ matrix.arch_label }})
needs: [select_image, build]
timeout-minutes: 360
timeout-minutes: 270
runs-on: ${{ matrix.arch_label == 'mi30x' && 'linux-te-mi30x-4' || 'linux-te-mi35x-4' }}
strategy:
fail-fast: false
Expand Down Expand Up @@ -184,11 +184,10 @@ jobs:
id: run-tests
# Below the job's timeout-minutes so an overrun kills only this step;
# the `if: always()` report + upload steps still run (artifacts survive).
timeout-minutes: 180
timeout-minutes: 240
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
rm -f FAIL_*
# One subdir per suite so PyTorch/JAX/Core get independent reports and
# never collide on identically-named XML (e.g. test_sanity_import.auto.xml
# exists in both the torch and jax suites and runs in parallel).
Expand All @@ -200,91 +199,10 @@ jobs:
-e JUNITXML_PREFIX=/workspace/test-results/ \
-e JUNITXML_SUFFIX=.xml \
-e HF_TOKEN="$HF_TOKEN" \
te-runner bash -c "$(cat <<'EOF'
#!/usr/bin/bash
set -x -o pipefail
ulimit -c 0 # Disable core dumps

# Each suite appends its own subdir to the inherited base prefix so the
# base path is defined once in the docker-exec env above (inline VAR=val
# applies only to that backgrounded subprocess; the sourced _utils.sh
# reads it from the env).
HIP_VISIBLE_DEVICES=0 JUNITXML_PREFIX=${JUNITXML_PREFIX}torch/ ci/pytorch.sh > /workspace/torch.log 2>&1 &
TORCH_PID=$!

HIP_VISIBLE_DEVICES=1 JUNITXML_PREFIX=${JUNITXML_PREFIX}jax/ ci/jax.sh > /workspace/jax.log 2>&1 &
JAX_PID=$!

(
set -e
python -c "import os; print('HF_TOKEN set:', bool(os.environ.get('HF_TOKEN')))"

JAX_CONSTRAINTS=/tmp/jax-constraints.txt
pip freeze | grep -iE '^(jax|jaxlib|jax[_-]rocm|jax[_-]plugins)[=@]' > "$JAX_CONSTRAINTS" || true

export HIP_VISIBLE_DEVICES=2

cd /workspace/examples/pytorch/mnist
python main.py
python main.py --use-te
python main.py --use-fp8

cd /workspace/examples/jax/mnist
pip3 install -c "$JAX_CONSTRAINTS" -r requirements.txt
python test_single_gpu_mnist.py
python test_single_gpu_mnist.py --use-te
python test_single_gpu_mnist.py --use-fp8

cd /workspace/examples/jax/encoder
pip3 install -c "$JAX_CONSTRAINTS" -r requirements.txt
python test_single_gpu_encoder.py
python test_single_gpu_encoder.py --use-fp8
) > /workspace/examples.log 2>&1 &
EXAMPLES_PID=$!

HIP_VISIBLE_DEVICES=3 JUNITXML_PREFIX=${JUNITXML_PREFIX}core/ ci/core.sh > /workspace/core.log 2>&1 &
CORE_PID=$!

wait $TORCH_PID; torch_rc=$?
wait $JAX_PID; jax_rc=$?
wait $EXAMPLES_PID; examples_rc=$?
wait $CORE_PID; core_rc=$?

if [ $torch_rc -ne 0 ]; then
echo "::group::[FAILED] PyTorch Log"
cat /workspace/torch.log
echo "::endgroup::"
echo "::error::PyTorch tests FAILED."
touch /workspace/FAIL_TORCH
fi

if [ $jax_rc -ne 0 ]; then
echo "::group::[FAILED] JAX Log"
cat /workspace/jax.log
echo "::endgroup::"
echo "::error::JAX tests FAILED."
touch /workspace/FAIL_JAX
fi

if [ $examples_rc -ne 0 ]; then
echo "::group::[FAILED] Examples Log"
cat /workspace/examples.log
echo "::endgroup::"
echo "::error::Examples FAILED."
touch /workspace/FAIL_EXAMPLES
fi

if [ $core_rc -ne 0 ]; then
echo "::group::[FAILED] Core Log"
cat /workspace/core.log
echo "::endgroup::"
echo "::error::Core tests FAILED."
touch /workspace/FAIL_CORE
fi

test $torch_rc -eq 0 -a $jax_rc -eq 0 -a $examples_rc -eq 0 -a $core_rc -eq 0
EOF
)"
te-runner bash /workspace/.github/scripts/run_parallel_sgpu.sh \
--first-gpu 0 \
--log-dir /workspace/ \
/workspace/.github/scripts/ci_sgpu_jobs.conf

- name: Generate test report
if: always()
Expand All @@ -304,22 +222,21 @@ jobs:
if: always()
run: |
EXIT_STATUS=0
if [[ -f FAIL_TORCH ]]; then
echo "::error::PyTorch tests failed."
EXIT_STATUS=1
fi
if [[ -f FAIL_JAX ]]; then
echo "::error::JAX tests failed."
EXIT_STATUS=1
fi
if [[ -f FAIL_EXAMPLES ]]; then
echo "::error::Examples failed."
EXIT_STATUS=1
fi
if [[ -f FAIL_CORE ]]; then
echo "::error::Core tests failed."
EXIT_STATUS=1
fi
for config in .github/scripts/ci_sgpu_jobs.conf; do
while IFS= read -r line || [[ -n "$line" ]]; do
[[ "$line" =~ ^[[:space:]]*# ]] && continue
[[ -z "${line//[[:space:]]/}" ]] && continue
read -r label logfile rest <<< "$line"
rc=$(cat "${logfile}.rc" 2>/dev/null || echo "missing")
if [[ "$rc" != "0" ]]; then
echo "::group::[FAILED] ${label} Log (rc=${rc})"
cat "$logfile" 2>/dev/null || echo "(log not found)"
echo "::endgroup::"
echo "::error::${label} tests FAILED."
EXIT_STATUS=1
fi
done < "$config"
done
exit $EXIT_STATUS

- name: Upload logs
Expand All @@ -340,7 +257,7 @@ jobs:
mgpu_tests:
name: mGPU ${{ matrix.framework == 'pytorch' && 'Torch' || 'JAX' }} (${{ matrix.arch_label }})
needs: [select_image, build]
timeout-minutes: 360
timeout-minutes: 210
runs-on: ${{ matrix.arch_label == 'mi30x' && 'linux-te-mi30x-8' || 'linux-te-mi35x-8' }}
strategy:
fail-fast: false
Expand Down
Loading
Loading