From 648fd483b9b79f3a6fe8c3ac010cb9aa38bb20cd Mon Sep 17 00:00:00 2001 From: ipanfilo <145064111+ipanfilo@users.noreply.github.com> Date: Mon, 13 Apr 2026 13:24:30 -0400 Subject: [PATCH 01/14] Fix TE loading w/o meta packages (#531) (cherry picked from commit 4297f0381f536394f9adb5cc7a64b2dc9470a9c1) --- transformer_engine/__init__.py | 4 +- transformer_engine/common/__init__.py | 117 ++++++++++++++++++-------- 2 files changed, 81 insertions(+), 40 deletions(-) diff --git a/transformer_engine/__init__.py b/transformer_engine/__init__.py index 71219deb1..b2b175892 100644 --- a/transformer_engine/__init__.py +++ b/transformer_engine/__init__.py @@ -86,8 +86,6 @@ try: __version__ = str(metadata.version("transformer_engine")) except metadata.PackageNotFoundError: - if not transformer_engine.common.te_rocm_build: - raise - _te_core_installed, _, __version__ = transformer_engine.common.get_te_core_package_info() + _te_core_installed, _, __version__ = transformer_engine.common.get_te_core_package_info(True) if not _te_core_installed: raise diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 662bc504c..9dbf998e5 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -133,14 +133,14 @@ def _get_shared_object_file(library: str) -> Path: ) -def get_te_core_package_info() -> Tuple[bool, str, str]: +def get_te_core_package_info(rocm:bool) -> Tuple[bool, str, str]: """ Check if Tranformer Engine core package is installed. Returns the module name and version if found. """ te_core_packages = ("transformer-engine-cu12", "transformer-engine-cu13") - if te_rocm_build: + if rocm: te_core_packages = ("transformer-engine-rocm7",) for package in te_core_packages: if _is_package_installed(package): @@ -162,14 +162,23 @@ def load_framework_extension(framework: str) -> None: module_name = f"transformer_engine_{framework}" # Name of the pip extra dependency for framework extensions from PyPI. - extra_dep_name = module_name + # ROCm: here is a bug in upstream code - using module name whereas it should be framework name. + extra_dep_name = framework if framework == "torch": extra_dep_name = "pytorch" + global te_rocm_build + package_name = module_name + if te_rocm_build: + package_name = f"transformer_engine_rocm_{framework}" + extra_dep_name = f"rocm_{extra_dep_name}" + # Find the TE packages. The core and framework packages can only be installed via PyPI. # For the `transformer-engine` package, we need to check explicity. - te_core_installed, te_core_package_name, te_core_version = get_te_core_package_info() - te_framework_installed = _is_package_installed(module_name) + te_core_installed, te_core_package_name, te_core_version = get_te_core_package_info( + te_rocm_build + ) + te_framework_installed = _is_package_installed(package_name) te_installed = _is_package_installed("transformer_engine") te_installed_via_pypi = _is_package_installed_from_wheel("transformer_engine") @@ -183,18 +192,27 @@ def load_framework_extension(framework: str) -> None: # If the framework extension pip package is installed, it means that TE is installed via # PyPI. For this case we need to make sure that the metapackage, the core lib, and framework # extension are all installed via PyPI and have matching versions. + # Metapackage and core lib matching is checked in `sanity_checks_for_pypi_installation()`, + # so here we only need to check the framework extension. if te_framework_installed: assert te_installed_via_pypi, "Could not find `transformer-engine` PyPI package." assert te_core_installed, ( "Could not find TE core package " f"`transformer-engine-{'rocm' if te_rocm_build else 'cu'}*`." ) - assert version(module_name) == version("transformer-engine") == te_core_version, ( + assert version(package_name) == te_core_version, ( "Transformer Engine package version mismatch. Found" - f" {module_name} v{version(module_name)}, transformer-engine" - f" v{version('transformer-engine')}, and {te_core_package_name}" + f" {package_name} v{version(package_name)} and {te_core_package_name}" f" v{te_core_version}. Install transformer-engine using " f"'pip3 install --no-build-isolation transformer-engine[{extra_dep_name}]==VERSION'" ) + else: + # If the framework extension package is not installed, it means that TE is either not + # installed via PyPI or it's an invalid installation with missing framework extension. + assert not te_installed_via_pypi, ( + f"Found `transformer-engine` PyPI package but not {package_name}." + " Install transformer-engine using " + f"'pip3 install --no-build-isolation transformer-engine[{extra_dep_name}]==VERSION'" + ) # After all checks are completed, load the shared object file. spec = importlib.util.spec_from_file_location(module_name, _get_shared_object_file(framework)) @@ -206,14 +224,29 @@ def load_framework_extension(framework: str) -> None: def sanity_checks_for_pypi_installation() -> None: """Ensure that package is installed correctly if using PyPI.""" - te_core_installed, te_core_package_name, te_core_version = get_te_core_package_info() + global te_rocm_build + te_core_installed, te_core_package_name, te_core_version = get_te_core_package_info(True) + _nv_core_installed, _nv_core_package_name, _nv_core_version = get_te_core_package_info(False) + if te_core_installed: + te_rocm_build = True + assert not _nv_core_installed, ( + f"Multiple core packages found: {te_core_package_name} and {_nv_core_package_name}." + " Please uninstall all `transformer-engine*` packages and reinstall the correct one." + ) + elif _nv_core_installed: + te_rocm_build = False + te_core_installed, te_core_package_name, te_core_version = ( + _nv_core_installed, _nv_core_package_name, _nv_core_version + ) te_installed = _is_package_installed("transformer_engine") te_installed_via_pypi = _is_package_installed_from_wheel("transformer_engine") - assert te_installed, "Could not find `transformer-engine`." + # Meta package is optional for ROCm build. + if not te_rocm_build: + assert te_installed, "Could not find `transformer-engine`." # If the core package is installed via PyPI. - if te_core_installed: + if te_core_installed and te_installed: assert te_installed_via_pypi, "Could not find `transformer-engine` PyPI package." assert version("transformer-engine") == te_core_version, ( "Transformer Engine package version mismatch. Found " @@ -225,7 +258,10 @@ def sanity_checks_for_pypi_installation() -> None: elif te_installed_via_pypi: raise RuntimeError( "Found empty `transformer-engine` meta package installed. " - "Install `transformer-engine` with framework extensions via" + "Install `transformer-engine` with framework extensions via " + "'pip3 install --no-build-isolation transformer-engine[rocm_pytorch,rocm_jax]'" + " or 'pip3 install transformer-engine[rocm]' for the ROCm TE core lib only." + " Or if you are using CUDA, install with " "'pip3 install --no-build-isolation transformer-engine[pytorch,jax]==VERSION'" " or 'pip3 install transformer-engine[core]` for the TE core lib only. The `core_cu12`" " or `core_cu13` extra deps can be used to specify CUDA version for the TE core lib." @@ -362,7 +398,7 @@ def _load_cuda_library(lib_name: str): raise RuntimeError(f"{lib_name} shared object not found.") -te_rocm_build = False +te_rocm_build = None @functools.cache def is_fp8_fnuz(): @@ -378,33 +414,40 @@ def _load_core_library(): if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): + sanity_checks_for_pypi_installation() + if not te_rocm_build: + try: + # `_load_cuda_library` is used for packages that must be loaded + # during runtime. Both system and pypi packages are searched + # and an error is thrown if not found. + _, _CUDNN_LIB_CTYPES = _load_cuda_library("cudnn") + system_nvrtc, _NVRTC_LIB_CTYPES = _load_cuda_library("nvrtc") + system_curand, _CURAND_LIB_CTYPES = _load_cuda_library("curand") + + # This additional step is necessary to be able to install TE wheels + # and import TE (without any guards) in an environment where the cuda + # toolkit might be absent without being guarded + load_libs_for_no_ctk = not system_nvrtc and not system_curand + if load_libs_for_no_ctk: + _CUBLAS_LIB_CTYPES = _load_cuda_library_from_python("cublas", strict=True) + _CUDART_LIB_CTYPES = _load_cuda_library_from_python("cudart", strict=True) + _CUDNN_ALL_LIB_CTYPES = _load_cuda_library_from_python("cudnn", strict=True) + except (OSError, RuntimeError, subprocess.CalledProcessError): + pass + + _TE_LIB_CTYPES = _load_core_library() try: - sanity_checks_for_pypi_installation() - - # `_load_cuda_library` is used for packages that must be loaded - # during runtime. Both system and pypi packages are searched - # and an error is thrown if not found. - _, _CUDNN_LIB_CTYPES = _load_cuda_library("cudnn") - system_nvrtc, _NVRTC_LIB_CTYPES = _load_cuda_library("nvrtc") - system_curand, _CURAND_LIB_CTYPES = _load_cuda_library("curand") - - # This additional step is necessary to be able to install TE wheels - # and import TE (without any guards) in an environment where the cuda - # toolkit might be absent without being guarded - load_libs_for_no_ctk = not system_nvrtc and not system_curand - if load_libs_for_no_ctk: - _CUBLAS_LIB_CTYPES = _load_cuda_library_from_python("cublas", strict=True) - _CUDART_LIB_CTYPES = _load_cuda_library_from_python("cudart", strict=True) - _CUDNN_ALL_LIB_CTYPES = _load_cuda_library_from_python("cudnn", strict=True) - except (OSError, RuntimeError, subprocess.CalledProcessError): - pass - finally: - _TE_LIB_CTYPES = _load_core_library() - try: - te_rocm_build = _TE_LIB_CTYPES.nvte_is_rocm_build() + _te_rocm_build = _TE_LIB_CTYPES.nvte_is_rocm_build() except AttributeError: # If the function is not available, we assume it's not a ROCm build - te_rocm_build = False + _te_rocm_build = False + if te_rocm_build is not None: + assert te_rocm_build == _te_rocm_build, ( + f"ROCm build mismatch. Detected ROCm installation: {te_rocm_build}," + f" but library API returns: {_te_rocm_build}." + ) + te_rocm_build = _te_rocm_build + if te_rocm_build: try: # Get installed ROCm version From 3588fbf3dad9aa8d52a59ac48b03deec1854cd8d Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Fri, 17 Apr 2026 10:10:38 -0500 Subject: [PATCH 02/14] [ROCm] fix the bug in hipfied optimized cast tranpose flow that two kernels got run (#545) * [ROCm] fix the bug in hipfied optimized cast tranpose flow that two kernels got run (cherry picked from commit a6470b08005d73e6d6611af1f683ccd4cf895f0b) --- ci/pytorch.sh | 1 + .../test_sanity_hipified_cast_transpose.py | 81 +++++++++++++++++++ .../common/transpose/cast_transpose.cu | 17 ++-- 3 files changed, 90 insertions(+), 9 deletions(-) create mode 100644 tests/pytorch/test_sanity_hipified_cast_transpose.py diff --git a/ci/pytorch.sh b/ci/pytorch.sh index d69009f1a..0d322f6c6 100755 --- a/ci/pytorch.sh +++ b/ci/pytorch.sh @@ -65,6 +65,7 @@ run_test_config(){ run_default_fa 1 test_recipe.py run 1 test_sanity.py run_default_fa 1 test_sanity_import.py + run_default_fa 3 test_sanity_hipified_cast_transpose.py run_default_fa 1 attention/test_attention.py # Backend selection is controlled by the test run_default_fa 1 attention/test_cp_utils.py run_default_fa 1 attention/test_kv_cache.py diff --git a/tests/pytorch/test_sanity_hipified_cast_transpose.py b/tests/pytorch/test_sanity_hipified_cast_transpose.py new file mode 100644 index 000000000..cedaa3c68 --- /dev/null +++ b/tests/pytorch/test_sanity_hipified_cast_transpose.py @@ -0,0 +1,81 @@ +# Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. +# License for AMD contributions = MIT. See LICENSE for more information + +""" +Non-regression tests for the hipified cast_transpose kernel dispatch. + +Verifies that tex.quantize (which routes through cast_transpose on AMD) +dispatches exactly one GPU kernel per call. A previous bug (commit 542b8c7b) +caused the NVTE_USE_OPTIMIZED_HIPIFIED_CAST_TRANSPOSE path to launch both +cast_transpose_optimized_kernel AND cast_transpose_general_kernel per call, +doubling the GPU work. +""" + +import os +import pytest +import torch +from torch.profiler import profile, ProfilerActivity + +from transformer_engine.pytorch import cpp_extensions as tex +from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer + + +def _fill_uniform(shape, dtype): + """Create a deterministic random tensor on GPU.""" + gen = torch.Generator(device="cuda") + gen.manual_seed(12345) + return torch.empty(shape, dtype=dtype, device="cuda").uniform_(-2.0, 2.0, generator=gen) + + +@pytest.mark.parametrize("shape", [ + (128, 128), + (2048, 12288), + (256, 256), +]) +@pytest.mark.parametrize("in_dtype", [torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("out_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) +@pytest.mark.parametrize("use_hipified_env", [False, True]) +def test_single_kernel_dispatch(shape, in_dtype, out_dtype, use_hipified_env): + """ + Verify that tex.quantize dispatches exactly one cast_transpose GPU kernel. + Tests both the default cost-model RTC path and the PR #89 hipified path + (NVTE_USE_OPTIMIZED_HIPIFIED_CAST_TRANSPOSE=1). + """ + input_tensor = _fill_uniform(shape, dtype=in_dtype) + scale = torch.rand(1, dtype=torch.float32, device="cuda") * 3.0 - 2.0 + amax = torch.zeros(1, dtype=torch.float32, device="cuda") + + env_key = "NVTE_USE_OPTIMIZED_HIPIFIED_CAST_TRANSPOSE" + old_val = os.environ.get(env_key) + + try: + if use_hipified_env: + os.environ[env_key] = "1" + else: + os.environ.pop(env_key, None) + + # Warmup (also triggers hipRTC compilation) + q = Float8Quantizer(scale=scale.clone(), amax=amax.clone(), fp8_dtype=out_dtype) + tex.quantize(input_tensor, q) + torch.cuda.synchronize() + + # Profiled run + q = Float8Quantizer(scale=scale.clone(), amax=amax.clone(), fp8_dtype=out_dtype) + with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: + tex.quantize(input_tensor, q) + torch.cuda.synchronize() + + ct_kernels = [ + evt.name for evt in prof.events() + if evt.device_type == torch.autograd.DeviceType.CUDA + and "cast_transpose" in evt.name + ] + + assert len(ct_kernels) == 1, ( + f"Expected exactly 1 cast_transpose kernel, got {len(ct_kernels)}: {ct_kernels}" + ) + finally: + if old_val is None: + os.environ.pop(env_key, None) + else: + os.environ[env_key] = old_val diff --git a/transformer_engine/common/transpose/cast_transpose.cu b/transformer_engine/common/transpose/cast_transpose.cu index 8fec36849..f4253026e 100644 --- a/transformer_engine/common/transpose/cast_transpose.cu +++ b/transformer_engine/common/transpose/cast_transpose.cu @@ -262,13 +262,10 @@ void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cu // Choose between runtime-compiled or statically-compiled kernel const bool aligned = (row_length % THREADS_PER_WARP == 0 && num_rows % THREADS_PER_WARP == 0); -#ifdef __HIP_PLATFORM_AMD__ - // do_general_config means using the cost model like NVTE to generate kernel configs - bool do_general_config = false; -#endif if (aligned && rtc::is_enabled()) { // Runtime-compiled tuned kernel #ifdef __HIP_PLATFORM_AMD__ - do_general_config = true; + // Whether to fall back to the cost-model-based RTC kernel selection + bool fallback_to_cost_model_rtc = true; // even if we enforce to use OPTIMIZED_HIPIFIED_CAST_TRANSPOSE, may fall back to general kernel configs from NVTE cost model bool nvte_use_optimized_hipified_cast_transpose = false; if (const char* env_p = std::getenv("NVTE_USE_OPTIMIZED_HIPIFIED_CAST_TRANSPOSE") ) { @@ -318,9 +315,9 @@ void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cu size_t num_blocks = (row_length / row_tile_elements) * (num_rows / col_tile_elements); size_t rtc_block_size = THREADS_PER_WARP * wpt_size; - do_general_config =!(row_length % row_tile_elements == 0 && num_rows % col_tile_elements == 0); + fallback_to_cost_model_rtc =!(row_length % row_tile_elements == 0 && num_rows % col_tile_elements == 0); - if(!do_general_config){ + if(!fallback_to_cost_model_rtc){ // Compile NVRTC kernel if needed and launch auto &rtc_manager = rtc::KernelManager::instance(); const std::string kernel_label = concat_strings( @@ -352,8 +349,7 @@ void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cu num_rows); } } - } - if(do_general_config){ + if(fallback_to_cost_model_rtc){ #endif // Pick kernel config std::vector kernel_configs; @@ -412,6 +408,9 @@ void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cu static_cast(output.scale.dptr), static_cast(output.amax.dptr), static_cast(output.scale_inv.dptr), row_length, num_rows); +#ifdef __HIP_PLATFORM_AMD__ + } +#endif } else { // Statically-compiled general kernel constexpr size_t load_size = 4; constexpr size_t store_size = 4; From e48f96240f12494997956463656c8692d6d0a858 Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Sun, 19 Apr 2026 21:58:50 -0500 Subject: [PATCH 03/14] PR545 hot fix for un-addressed reviewer comments (#549) * [ROCm] fix the bug in hipfied optimized cast tranpose flow that two kernels got run * [ROCm] move the if(fallback_to_cost_model_rtc) branch into the upstream rtc branch * [ROCm] address reviewer comments left in PR545 (cherry picked from commit 789035adbd08bd0ca040b38617e8e61680f5a7d9) --- ci/pytorch.sh | 2 +- .../test_sanity_hipified_cast_transpose.py | 51 ++++++++----------- 2 files changed, 22 insertions(+), 31 deletions(-) diff --git a/ci/pytorch.sh b/ci/pytorch.sh index 0d322f6c6..34ff01d3c 100755 --- a/ci/pytorch.sh +++ b/ci/pytorch.sh @@ -64,8 +64,8 @@ run_test_config(){ run_default_fa 1 test_permutation.py run_default_fa 1 test_recipe.py run 1 test_sanity.py - run_default_fa 1 test_sanity_import.py run_default_fa 3 test_sanity_hipified_cast_transpose.py + run_default_fa 1 test_sanity_import.py run_default_fa 1 attention/test_attention.py # Backend selection is controlled by the test run_default_fa 1 attention/test_cp_utils.py run_default_fa 1 attention/test_kv_cache.py diff --git a/tests/pytorch/test_sanity_hipified_cast_transpose.py b/tests/pytorch/test_sanity_hipified_cast_transpose.py index cedaa3c68..0a62ea30d 100644 --- a/tests/pytorch/test_sanity_hipified_cast_transpose.py +++ b/tests/pytorch/test_sanity_hipified_cast_transpose.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # License for AMD contributions = MIT. See LICENSE for more information """ @@ -11,7 +11,6 @@ doubling the GPU work. """ -import os import pytest import torch from torch.profiler import profile, ProfilerActivity @@ -35,7 +34,7 @@ def _fill_uniform(shape, dtype): @pytest.mark.parametrize("in_dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize("out_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) @pytest.mark.parametrize("use_hipified_env", [False, True]) -def test_single_kernel_dispatch(shape, in_dtype, out_dtype, use_hipified_env): +def test_single_kernel_dispatch(shape, in_dtype, out_dtype, use_hipified_env, monkeypatch): """ Verify that tex.quantize dispatches exactly one cast_transpose GPU kernel. Tests both the default cost-model RTC path and the PR #89 hipified path @@ -46,36 +45,28 @@ def test_single_kernel_dispatch(shape, in_dtype, out_dtype, use_hipified_env): amax = torch.zeros(1, dtype=torch.float32, device="cuda") env_key = "NVTE_USE_OPTIMIZED_HIPIFIED_CAST_TRANSPOSE" - old_val = os.environ.get(env_key) + if use_hipified_env: + monkeypatch.setenv(env_key, "1") + else: + monkeypatch.setenv(env_key, "0") - try: - if use_hipified_env: - os.environ[env_key] = "1" - else: - os.environ.pop(env_key, None) + # Warmup (also triggers hipRTC compilation) + q = Float8Quantizer(scale=scale.clone(), amax=amax.clone(), fp8_dtype=out_dtype) + tex.quantize(input_tensor, q) + torch.cuda.synchronize() - # Warmup (also triggers hipRTC compilation) - q = Float8Quantizer(scale=scale.clone(), amax=amax.clone(), fp8_dtype=out_dtype) + # Profiled run + q = Float8Quantizer(scale=scale.clone(), amax=amax.clone(), fp8_dtype=out_dtype) + with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: tex.quantize(input_tensor, q) torch.cuda.synchronize() - # Profiled run - q = Float8Quantizer(scale=scale.clone(), amax=amax.clone(), fp8_dtype=out_dtype) - with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: - tex.quantize(input_tensor, q) - torch.cuda.synchronize() + ct_kernels = [ + evt.name for evt in prof.events() + if evt.device_type == torch.autograd.DeviceType.CUDA + and "cast_transpose" in evt.name + ] - ct_kernels = [ - evt.name for evt in prof.events() - if evt.device_type == torch.autograd.DeviceType.CUDA - and "cast_transpose" in evt.name - ] - - assert len(ct_kernels) == 1, ( - f"Expected exactly 1 cast_transpose kernel, got {len(ct_kernels)}: {ct_kernels}" - ) - finally: - if old_val is None: - os.environ.pop(env_key, None) - else: - os.environ[env_key] = old_val + assert len(ct_kernels) == 1, ( + f"Expected exactly 1 cast_transpose kernel, got {len(ct_kernels)}: {ct_kernels}" + ) From ee1e3c59f68339203fde39e4b32c8e7fabd4a305 Mon Sep 17 00:00:00 2001 From: ipanfilo <145064111+ipanfilo@users.noreply.github.com> Date: Tue, 28 Apr 2026 01:26:10 -0400 Subject: [PATCH 04/14] Ipanfilo/wheel build action (#529) * GHA to build release wheel set * Suppress verbose logging from AOTriton build * Decrease verbosity of hipification (cherry picked from commit e6b79aff67f009dc38e58d799714b874d288e007) --- .github/workflows/rocm-wheels-build.yml | 233 ++++++++++++++++++ build_tools/hipify/hipify.py | 1 + .../common/aotriton/CMakeLists.txt | 4 + 3 files changed, 238 insertions(+) create mode 100644 .github/workflows/rocm-wheels-build.yml diff --git a/.github/workflows/rocm-wheels-build.yml b/.github/workflows/rocm-wheels-build.yml new file mode 100644 index 000000000..c1a8ea087 --- /dev/null +++ b/.github/workflows/rocm-wheels-build.yml @@ -0,0 +1,233 @@ +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. + +# GitHub Actions workflow to build TransformerEngine release wheels for ROCm +# on a manylinux_2_28_x86_64 base image. +# +# Single-job pipeline that: +# 1. Checks Harbor for existing ROCm Docker image or builds new one +# 2. Runs the image to compile and package TE wheels +# + +name: ROCm Wheels Build + +on: + workflow_dispatch: + inputs: + use_local_version: + description: 'Use local version suffix for generated wheels.' + type: boolean + required: false + default: true + use_prebuilt_aiter: + description: 'Use precompiled aiter instead of building from source.' + type: boolean + required: false + default: true + rocm_repo_url: + description: 'ROCm TheRock repository URL used inside the Docker image.' + type: string + required: true + default: 'https://repo.amd.com/rocm/packages/rhel8/x86_64/' + force_build_image: + description: 'Force building of ROCm Docker image, skip checking intermediate storage.' + type: boolean + required: false + default: true + upload_image: + description: 'Upload built ROCm Docker image to intermediate storage.' + type: boolean + required: false + default: false + docker_image_tag_override: + description: 'Override the auto-generated ROCm Docker image tag.' + type: string + required: false + default: '' + workflow_call: + inputs: + use_local_version: + type: boolean + default: true + use_prebuilt_aiter: + type: boolean + default: true + rocm_repo_url: + type: string + default: 'https://repo.amd.com/rocm/packages/rhel8/x86_64/' + force_build_image: + type: boolean + default: true + upload_image: + type: boolean + default: false + docker_image_tag_override: + type: string + default: '' + +env: + DOCKER_IMAGE_NAME: te-rocm-manylinux-x86 + MANYLINUX_PLATFORM: manylinux_2_28_x86_64 + +# ───────────────────────────────────────────────────────────────────────────── +jobs: + + build-rocm-wheels: + name: Build ROCm Docker image and TransformerEngine wheels + runs-on: build-only-te + + steps: + - name: Checkout repository + uses: actions/checkout@v6 + + - name: Initialize required submodules + run: | + git submodule update --init --recursive --depth 1 \ + 3rdparty/aotriton \ + 3rdparty/aiter \ + 3rdparty/QoLA \ + 3rdparty/hipify_torch + + - name: Derive Docker image tag + id: set-tag + run: | + # Harbor registry configuration + HARBOR_REGISTRY="registry-sc-harbor.amd.com" + HARBOR_PROJECT="framework/te-ci" + + # Build the image tag + if [ -n "${{ inputs.docker_image_tag_override }}" ]; then + # Use override if provided + IMAGE_TAG="${{ inputs.docker_image_tag_override }}" + echo "Using override tag: ${IMAGE_TAG}" + else + # Get current year-month + YEAR_MONTH=$(date +"%Y%m") + + # Try to detect ROCm revision from repomd.xml + ROCM_URL="${{ inputs.rocm_repo_url }}" + ROCM_REVISION=$(curl -s "${ROCM_URL}/repodata/repomd.xml" 2>/dev/null | grep -oP '(?<=)[^<]+' | head -n1 || echo "") + + if [ -n "$ROCM_REVISION" ]; then + # Use year-month with ROCm revision + IMAGE_TAG="${YEAR_MONTH}-${ROCM_REVISION}" + echo "Generated tag with ROCm revision: ${IMAGE_TAG}" + else + # Use year-month only (fallback) + IMAGE_TAG="${YEAR_MONTH}" + echo "Generated tag without ROCm revision (failed to detect): ${IMAGE_TAG}" + fi + fi + + # Set local and Harbor image tags + LOCAL_IMAGE_TAG="${{ env.DOCKER_IMAGE_NAME }}:${IMAGE_TAG}" + HARBOR_IMAGE_TAG="${HARBOR_REGISTRY}/${HARBOR_PROJECT}/${{ env.DOCKER_IMAGE_NAME }}:${IMAGE_TAG}" + + echo "Image tag: ${LOCAL_IMAGE_TAG}" + + echo "image_tag=${LOCAL_IMAGE_TAG}" >> "$GITHUB_OUTPUT" + echo "harbor_image_tag=${HARBOR_IMAGE_TAG}" >> "$GITHUB_OUTPUT" + echo "harbor_registry=${HARBOR_REGISTRY}" >> "$GITHUB_OUTPUT" + + - name: Check for existing image in Harbor + id: check-harbor + if: ${{ !inputs.force_build_image }} + run: | + echo "Checking if image exists in Harbor: ${{ steps.set-tag.outputs.harbor_image_tag }}" + + # Try to pull the image from Harbor (no authentication needed for pull) + if docker pull "${{ steps.set-tag.outputs.harbor_image_tag }}" 2>/dev/null; then + echo "Image found in Harbor, will reuse" + # Tag the Harbor image with local tag for consistency + docker tag "${{ steps.set-tag.outputs.harbor_image_tag }}" "${{ steps.set-tag.outputs.image_tag }}" + echo "image_exists=true" >> "$GITHUB_OUTPUT" + else + echo "Image not found in Harbor, will build new image" + echo "image_exists=false" >> "$GITHUB_OUTPUT" + fi + + # The build context is build_tools/wheel_utils/ because the Dockerfile + # COPYs build_wheels.sh from that directory. + - name: Build Docker image + if: ${{ inputs.force_build_image || steps.check-harbor.outputs.image_exists != 'true' }} + run: | + echo "Building Docker image: ${{ steps.set-tag.outputs.image_tag }}" + docker build \ + --no-cache \ + --build-arg ROCM_REPO_URL="${{ inputs.rocm_repo_url }}" \ + --tag "${{ steps.set-tag.outputs.image_tag }}" \ + --file build_tools/wheel_utils/Dockerfile.rocm.manylinux.x86 \ + build_tools/wheel_utils/ + + - name: Upload Docker image to Harbor + if: ${{ inputs.upload_image && (inputs.force_build_image || steps.check-harbor.outputs.image_exists != 'true') }} + env: + HARBOR_USERNAME: ${{ secrets.HARBOR_USERNAME }} + HARBOR_PASSWORD: ${{ secrets.HARBOR_PASSWORD }} + run: | + HARBOR_REGISTRY="${{ steps.set-tag.outputs.harbor_registry }}" + + echo "Logging in to Harbor registry: ${HARBOR_REGISTRY}" + echo "$HARBOR_PASSWORD" | docker login "${HARBOR_REGISTRY}" -u "$HARBOR_USERNAME" --password-stdin + + echo "Tagging image for Harbor" + docker tag "${{ steps.set-tag.outputs.image_tag }}" "${{ steps.set-tag.outputs.harbor_image_tag }}" + + echo "Pushing image to Harbor: ${{ steps.set-tag.outputs.harbor_image_tag }}" + docker push "${{ steps.set-tag.outputs.harbor_image_tag }}" + + echo "Logging out from Harbor" + docker logout "${HARBOR_REGISTRY}" + + - name: Create wheelhouse directory + run: mkdir -p "${{ runner.temp }}/wheelhouse" + + # Mount the checked-out workspace to /TransformerEngine in the container. + # The container writes all wheels and logs under /wheelhouse. + - name: Build TransformerEngine wheels + run: | + NVTE_AITER_PREBUILT_BASE_URL="https://compute-artifactory.amd.com:5000/artifactory/rocm-generic-local/te-ci/aiter-prebuilts" + docker run --rm \ + --env LOCAL_TREE_BUILD=1 \ + --env NVTE_SKIP_SUBMODULE_CHECKS_DURING_BUILD=1 \ + --env TARGET_BRANCH="${{ github.ref_name }}" \ + ${{ inputs.use_local_version && '--env NVTE_USE_LOCAL_VERSION=1' || '' }} \ + ${{ inputs.use_prebuilt_aiter && '--env NVTE_AITER_PREBUILT_BASE_URL="${NVTE_AITER_PREBUILT_BASE_URL}"' || '' }} \ + --volume "${{ github.workspace }}:/TransformerEngine" \ + --volume "${{ runner.temp }}/wheelhouse:/wheelhouse" \ + "${{ steps.set-tag.outputs.image_tag }}" + + - name: List built artifacts + if: always() + run: | + echo "=== Wheels and source distributions ===" + find "${{ runner.temp }}/wheelhouse" \ + \( -name "*.whl" -o -name "*.tar.gz" \) | sort + + echo "" + echo "=== Build logs ===" + LOG_DIR="${{ runner.temp }}/wheelhouse/logs" + if [ -d "$LOG_DIR" ]; then + find "$LOG_DIR" -type f | sort + fi + + - name: Upload wheels as GitHub Actions artifacts + if: success() + uses: actions/upload-artifact@v4 + with: + name: te-rocm-wheels + path: | + ${{ runner.temp }}/wheelhouse/*.whl + ${{ runner.temp }}/wheelhouse/*.tar.gz + retention-days: 1 + if-no-files-found: error + + - name: Upload build logs on failure + if: failure() + uses: actions/upload-artifact@v4 + with: + name: te-rocm-build-logs + path: ${{ runner.temp }}/wheelhouse/logs/ + retention-days: 30 + if-no-files-found: warn diff --git a/build_tools/hipify/hipify.py b/build_tools/hipify/hipify.py index e295c5be9..57dfeebc6 100644 --- a/build_tools/hipify/hipify.py +++ b/build_tools/hipify/hipify.py @@ -59,6 +59,7 @@ def do_hipify(te_root: Union[Path, str], src_dir: Union[Path, str], is_pytorch_extension=True, hipify_extra_files_only=False, show_detailed=False, + show_progress=False, no_math_replace=True) # Convert hipify objects to dictionaries for consistent behavior diff --git a/transformer_engine/common/aotriton/CMakeLists.txt b/transformer_engine/common/aotriton/CMakeLists.txt index c4be66106..e5da103a2 100644 --- a/transformer_engine/common/aotriton/CMakeLists.txt +++ b/transformer_engine/common/aotriton/CMakeLists.txt @@ -57,6 +57,8 @@ if(NOT DEFINED AOTRITON_PATH) "${__AOTRITON_INSTALL_DIR}" BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/aotriton.images/${image}/__signature__" + LOG_DOWNLOAD ON # redirect download stdout to log file + LOG_OUTPUT_ON_FAILURE ON # but still show output if build fails ) message(STATUS "Download AOTriton pre-compiled GPU images from ${__AOTRITON_URL}.") endfunction() @@ -79,6 +81,8 @@ if(NOT DEFINED AOTRITON_PATH) -DTE_AOTRITON_COMMIT_SHA1=${AOTRITON_SHA} -DCMAKE_PROJECT_INCLUDE=${CMAKE_CURRENT_LIST_DIR}/aotriton_custom.cmake BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton${__AOTRITON_SUFFIX}_v2.so" + LOG_BUILD ON # redirect build stdout to log file + LOG_OUTPUT_ON_FAILURE ON # but still show output if build fails ) message(STATUS "Adding AOTriton library.") add_dependencies(aotriton aotriton_external) From 49100b732a9aa49df439d15978c670a11e524c14 Mon Sep 17 00:00:00 2001 From: Leo Date: Wed, 29 Apr 2026 10:12:18 +0200 Subject: [PATCH 05/14] CI: Refactor ROCm CI to use GPU-sized runners and build-only jobs (#528) * CI: Refactor ROCm CI to use GPU-sized runners and build-only jobs * Update labels * Shallow clone * Address comments * Add a missing submodule * Address comments * Cleanup * Address comments * Fix NVTE_FRAMEWORK (cherry picked from commit 6b96c4692ffa3282518fff3f54dc7ee7a8e04cf1) --- .github/workflows/aiter-prebuilt-upload.yml | 6 +- .github/workflows/rocm-ci.yml | 566 ++++++++++---------- 2 files changed, 277 insertions(+), 295 deletions(-) diff --git a/.github/workflows/aiter-prebuilt-upload.yml b/.github/workflows/aiter-prebuilt-upload.yml index f9d2a91d7..a45350a79 100644 --- a/.github/workflows/aiter-prebuilt-upload.yml +++ b/.github/workflows/aiter-prebuilt-upload.yml @@ -13,7 +13,7 @@ on: jobs: upload: - runs-on: linux-te-mi325-8 + runs-on: build-only-te steps: - name: Checkout source uses: actions/checkout@v6 @@ -44,11 +44,7 @@ jobs: --rm \ --name te-aiter-upload \ --network=host \ - --device=/dev/dri --device=/dev/kfd \ - --shm-size=16G \ --pid=host \ - --group-add $(getent group render | cut -d: -f3) \ - --group-add $(getent group video | cut -d: -f3) \ -v "${{ github.workspace }}:/workspace" \ -w /workspace \ ${{ steps.cfg.outputs.image }} diff --git a/.github/workflows/rocm-ci.yml b/.github/workflows/rocm-ci.yml index c85f1bed2..5e0ae242c 100644 --- a/.github/workflows/rocm-ci.yml +++ b/.github/workflows/rocm-ci.yml @@ -45,132 +45,107 @@ concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true +env: + TEST_LEVEL: ${{ (github.event_name == 'push' && '3') || inputs.test_level || '1' }} + jobs: - build_and_test: - name: Build and Test on GPU (${{ matrix.runner }}) - Level ${{ (github.event_name == 'push' && '3') || inputs.test_level || '1' }} - timeout-minutes: 720 - runs-on: ${{ matrix.runner }} - strategy: - fail-fast: false - matrix: - runner: [linux-te-mi325-8, linux-te-mi35x-8] + select_image: + name: Select Docker Image + runs-on: ubuntu-latest + timeout-minutes: 10 + outputs: + image-tag: ${{ steps.select-image.outputs.image-tag }} steps: - name: Checkout repository uses: actions/checkout@v6 with: - submodules: 'recursive' - - - name: Host Diagnostics & Environment Setup - id: host-setup - run: | - # Host Activity Checks - echo "::group::Host Diagnostics" - - echo ">>> Active Containers:" - docker ps -a - - echo ">>> ROCm Installation:" - (ls -d /opt/rocm/core-* || ls -d /opt/rocm-* || echo "No default ROCm path found") 2>/dev/null || true - echo ">>> GPU info:" - ls -l /dev/dri - ls -l /dev/kfd - rocm-smi - - echo ">>> Kernel Command Line:" - cat /proc/cmdline - echo "::endgroup::" - - # Calculate Test Level - # Default to input (or '1' if input is missing/null) - CALC_LEVEL="${{ inputs.test_level || '1' }}" - - # Only force Level 3 if this is a direct PUSH to dev or a release branch - if [[ "${{ github.event_name }}" == "push" ]]; then - echo "::notice::Push to monitored branch (${{ github.ref_name }}) detected. Forcing Level 3." - CALC_LEVEL="3" - fi - - echo "TEST_LEVEL=$CALC_LEVEL" >> $GITHUB_ENV - - # Print Final Environment - echo "::group::Environment & Parameters" - echo "Final Test Level: $CALC_LEVEL" - echo "Event Name: ${{ github.event_name }}" - echo "Ref Name: ${{ github.ref_name }}" - echo "Base Ref: ${{ github.base_ref }}" - env | sort - echo "::endgroup::" + ref: ${{ inputs.test_config_from_source && github.ref_name || github.event.repository.default_branch || 'dev' }} + sparse-checkout: ci/ci_config.json + sparse-checkout-cone-mode: false - name: Select Docker Image Tag id: select-image run: | - # Determine config source - # Default we are fetching from 'dev' branch - CONFIG_BRANCH="dev" - - # If manual run requesting source config, switch branch if [[ "${{ inputs.test_config_from_source }}" == "true" ]]; then - CONFIG_BRANCH="${{ github.ref_name }}" - echo "::notice::Debugging mode: Fetching config from current branch ($CONFIG_BRANCH)" + echo "::notice::Debugging mode: Using ci/ci_config.json from ${{ github.ref_name }}" + else + echo "::notice::Using ci/ci_config.json from ${{ github.event.repository.default_branch || 'dev' }}" fi - # Download config - CONFIG_URL="https://raw.githubusercontent.com/ROCm/TransformerEngine/${CONFIG_BRANCH}/ci/ci_config.json" - echo "Attempting to fetch image config from: $CONFIG_URL" - - if curl -s -f -o docker_config.json "$CONFIG_URL"; then - echo "Successfully downloaded config from $CONFIG_BRANCH." - else - echo "::warning::Failed to fetch config from $CONFIG_BRANCH (File might not exist yet)." - - # Fallback: Check source branch file - if [[ -f "ci/ci_config.json" ]]; then - echo "::notice::Falling back to local 'ci/ci_config.json' from checkout." - cp ci/ci_config.json docker_config.json - else - echo "::error::Config file not found in $CONFIG_BRANCH OR locally." - exit 1 - fi + if [[ ! -f "ci/ci_config.json" ]]; then + echo "::error::Config file not found in checkout." + exit 1 fi - # Determine image key BRANCH_NAME="${{ github.base_ref || github.ref_name }}" echo "Determining image for branch: $BRANCH_NAME" - - # Logic: Check if branch matches "release_vX.X". - # If so, look for that key in JSON. Otherwise default. - JSON_KEY="default" - - if [[ $BRANCH_NAME =~ ^release_v([0-9]+\.[0-9]+)_rocm$ ]]; then - VERSION_KEY="release_v${BASH_REMATCH[1]}" - # Check if this specific version key exists in the JSON - if [[ $(jq "(.docker_images | has(\"$VERSION_KEY\"))" docker_config.json) == "true" ]]; then - JSON_KEY="$VERSION_KEY" - fi + VERSION_KEY="$BRANCH_NAME" + + if jq -e --arg key "$VERSION_KEY" '.docker_images[$key]' ci/ci_config.json > /dev/null; then + JSON_KEY="$VERSION_KEY" + else + JSON_KEY="default" fi - - echo "Selected config key: $JSON_KEY" - # Extract image name from json - IMAGE_TO_USE=$(jq -r ".docker_images.\"$JSON_KEY\"" docker_config.json) + echo "Selected config key: $JSON_KEY" + IMAGE_TO_USE=$(jq -r --arg key "$JSON_KEY" '.docker_images[$key]' ci/ci_config.json) - # Check input from workflow_dispatch overriding the image MANUAL_OVERRIDE="${{ inputs.docker_image_override }}" if [[ -n "$MANUAL_OVERRIDE" ]]; then echo "::notice::Manual override detected: $MANUAL_OVERRIDE" IMAGE_TO_USE="$MANUAL_OVERRIDE" fi - + echo "Selected image: $IMAGE_TO_USE" echo "image-tag=$IMAGE_TO_USE" >> $GITHUB_OUTPUT + build: + # Delegate wheel building to the reusable workflow on dev. It produces a core .whl plus framework .tar.gz sdists under artifact name `te-rocm-wheels`. + uses: ./.github/workflows/rocm-wheels-build.yml + secrets: inherit + + sgpu_tests: + name: sGPU Tests (${{ matrix.arch_label }}) + needs: [select_image, build] + timeout-minutes: 360 + runs-on: ${{ matrix.arch_label == 'mi30x' && 'linux-te-mi30x-4' || 'linux-te-mi35x-4' }} + strategy: + fail-fast: false + matrix: + arch_label: [mi30x, mi35x] + steps: + - name: Checkout repository + uses: actions/checkout@v6 + + - name: Initialize required submodules + run: | + git submodule update --init --recursive --depth 1 \ + 3rdparty/googletest \ + 3rdparty/hipify_torch + + - name: Download build artifacts + uses: actions/download-artifact@v4 + with: + name: te-rocm-wheels + path: dist/ + + - name: Host Diagnostics + run: | + echo "::group::Host Diagnostics" + echo ">>> GPU info:" + ls -l /dev/dri + ls -l /dev/kfd + rocm-smi + echo "::endgroup::" + - name: Pull Docker Image run: | - docker pull ${{ steps.select-image.outputs.image-tag }} + docker pull ${{ needs.select_image.outputs.image-tag }} - name: Run Container run: | docker run -dt \ + --rm \ --name te-runner \ --network=host \ --device=/dev/dri --device=/dev/kfd \ @@ -180,257 +155,268 @@ jobs: --group-add $(getent group video | cut -d: -f3) \ -v "${{ github.workspace }}:/workspace" \ -w /workspace \ - ${{ steps.select-image.outputs.image-tag}} + ${{ needs.select_image.outputs.image-tag }} - - name: Container Diagnostics & GPU Setup - id: container-diag + - name: Install packages run: | - echo "::group::Container Configuration" - # Check Shared Memory Size inside container - echo ">>> /dev/shm size:" - docker exec te-runner df -h /dev/shm - - # Check OS/Kernel inside container - echo ">>> Container OS:" - docker exec te-runner cat /etc/os-release | grep PRETTY_NAME - echo "::endgroup::" - - echo "::group::ROCm Diagnostics (Host vs Container)" - echo ">>> CONTAINER rocm-smi:" - docker exec te-runner rocm-smi || true - echo "::endgroup::" - - # Determine Architecture - # Run rocminfo inside the container and capture the output - ARCH=$(docker exec te-runner bash -c "rocminfo | grep -m 1 -oP 'gfx[0-9a-fA-F]+'") - - if [ -z "$ARCH" ]; then - echo "::error::Could not determine GPU architecture using rocminfo inside the container." - docker exec te-runner rocminfo - exit 1 - fi - - echo "Detected GPU Arch: $ARCH" - echo "arch=$ARCH" >> $GITHUB_OUTPUT - - - name: Build Project - run: | - docker exec \ - -e GPU_ARCH=${{ steps.container-diag.outputs.arch }} \ - te-runner bash -c "$(cat <<'EOF' + docker exec te-runner bash -c "$(cat <<'EOF' set -ex - - export HIP_PATH="" - export PYTORCH_ROCM_ARCH=$GPU_ARCH - export NVTE_ROCM_ARCH=$GPU_ARCH - export NVTE_AITER_PREBUILT_BASE_URL=https://compute-artifactory.amd.com:5000/artifactory/rocm-generic-local/te-ci/aiter-prebuilts - pip install ninja + # core (cpp) tests build via cmake inside the repo; allow git ops in-tree. git config --global --add safe.directory '*' - pip install --no-build-isolation -v . 2>&1 + + TE_CORE_PKG=$(find /workspace/dist -type f -name 'transformer_engine_rocm[0-9]*.whl' | sort | head -n 1) + TE_TORCH_PKG=$(find /workspace/dist -type f -name 'transformer_engine_rocm_torch*.tar.gz' | sort | head -n 1) + TE_JAX_PKG=$(find /workspace/dist -type f -name 'transformer_engine_rocm_jax*.tar.gz' | sort | head -n 1) + test -n "$TE_CORE_PKG" && test -n "$TE_TORCH_PKG" && test -n "$TE_JAX_PKG" + + pip install --no-deps "$TE_CORE_PKG" + pip install ninja pybind11[global] + pip install --no-build-isolation --no-deps "$TE_TORCH_PKG" + pip install --no-build-isolation --no-deps "$TE_JAX_PKG" EOF )" - - name: Run sGPU tests - id: sgpu-tests - continue-on-error: true + - name: Run sGPU tests in parallel (pytorch, jax, examples, core) + id: run-tests + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} run: | - # Cleanup previous failure markers if any. Don't actually do anything on k8s pods rm -f FAIL_* docker exec \ -e TEST_SGPU=1 \ -e TEST_LEVEL=${{ env.TEST_LEVEL }} \ + -e HF_TOKEN="$HF_TOKEN" \ te-runner bash -c "$(cat <<'EOF' #!/usr/bin/bash set -x -o pipefail ulimit -c 0 # Disable core dumps - HIP_VISIBLE_DEVICES=1 ci/pytorch.sh > /workspace/torch_sgpu.log 2>&1 & - torch_pid=$!; echo Pytorch test pid $! - - HIP_VISIBLE_DEVICES=2 ci/jax.sh > /workspace/jax_sgpu.log 2>&1 & - jax_pid=$!; echo JAX test pid $! - - HIP_VISIBLE_DEVICES=3 ci/core.sh > /workspace/core_sgpu.log 2>&1 & - core_pid=$!; echo Core test pid $! - - wait $core_pid; core_rc=$? - wait $jax_pid; jax_rc=$? - wait $torch_pid; torch_rc=$? - - # /workspace/FAIL_* files are for failure markers we can extract to the host runner and process later - # Check PyTorch - if [ $torch_rc -ne 0 ]; then - echo "::group::[FAILED] PyTorch sGPU Log" - cat /workspace/torch_sgpu.log - echo "::endgroup::" - echo "::error::Pytorch sGPU test FAILED." - touch /workspace/FAIL_TORCH_SGPU - fi + HIP_VISIBLE_DEVICES=0 ci/pytorch.sh > /workspace/torch.log 2>&1 & + TORCH_PID=$! + + HIP_VISIBLE_DEVICES=1 ci/jax.sh > /workspace/jax.log 2>&1 & + JAX_PID=$! + + ( + set -e + python -c "import os; print('HF_TOKEN set:', bool(os.environ.get('HF_TOKEN')))" + + JAX_CONSTRAINTS=/tmp/jax-constraints.txt + pip freeze | grep -iE '^(jax|jaxlib|jax[_-]rocm|jax[_-]plugins)[=@]' > "$JAX_CONSTRAINTS" || true - # Check JAX - if [ $jax_rc -ne 0 ]; then - echo "::group::[FAILED] JAX sGPU Log" - cat /workspace/jax_sgpu.log + export HIP_VISIBLE_DEVICES=2 + + cd /workspace/examples/pytorch/mnist + python main.py + python main.py --use-te + python main.py --use-fp8 + + cd /workspace/examples/jax/mnist + pip3 install -c "$JAX_CONSTRAINTS" -r requirements.txt + python test_single_gpu_mnist.py + python test_single_gpu_mnist.py --use-te + python test_single_gpu_mnist.py --use-fp8 + + cd /workspace/examples/jax/encoder + pip3 install -c "$JAX_CONSTRAINTS" -r requirements.txt + python test_single_gpu_encoder.py + python test_single_gpu_encoder.py --use-fp8 + ) > /workspace/examples.log 2>&1 & + EXAMPLES_PID=$! + + HIP_VISIBLE_DEVICES=3 ci/core.sh > /workspace/core.log 2>&1 & + CORE_PID=$! + + wait $TORCH_PID; torch_rc=$? + wait $JAX_PID; jax_rc=$? + wait $EXAMPLES_PID; examples_rc=$? + wait $CORE_PID; core_rc=$? + + if [ $torch_rc -ne 0 ]; then + echo "::group::[FAILED] PyTorch Log" + cat /workspace/torch.log echo "::endgroup::" - echo "::error::JAX sGPU test FAILED." - touch /workspace/FAIL_JAX_SGPU + echo "::error::PyTorch tests FAILED." + touch /workspace/FAIL_TORCH fi - # Check Core - if [ $core_rc -ne 0 ]; then - echo "::group::[FAILED] Core sGPU Log" - cat /workspace/core_sgpu.log + if [ $jax_rc -ne 0 ]; then + echo "::group::[FAILED] JAX Log" + cat /workspace/jax.log echo "::endgroup::" - echo "::error::Core sGPU test FAILED." - touch /workspace/FAIL_CORE_SGPU + echo "::error::JAX tests FAILED." + touch /workspace/FAIL_JAX fi - - test $torch_rc -eq 0 -a $jax_rc -eq 0 -a $core_rc -eq 0 - EOF - )" - - # Export failed tests statuses to host runner - if [ -f FAIL_TORCH_SGPU ]; then echo "torch=fail" >> $GITHUB_OUTPUT; fi - if [ -f FAIL_JAX_SGPU ]; then echo "jax=fail" >> $GITHUB_OUTPUT; fi - if [ -f FAIL_CORE_SGPU ]; then echo "core=fail" >> $GITHUB_OUTPUT; fi - - name: Run mGPU tests - id: mgpu-tests - continue-on-error: true - run: | - docker exec \ - -e TEST_MGPU=1 \ - -e TEST_LEVEL=${{ env.TEST_LEVEL }} \ - te-runner bash -c "$(cat <<'EOF' - #!/usr/bin/bash - set -x -o pipefail - ulimit -c 0 # Disable core dumps - - # Run PyTorch - ci/pytorch.sh > /workspace/torch_mgpu.log 2>&1 - torch_rc=$? - - # Run JAX - ci/jax.sh > /workspace/jax_mgpu.log 2>&1 - jax_rc=$? - - # /workspace/FAIL_* files are for failure markers we can extract to the host runner and process later - if [ $torch_rc -ne 0 ]; then - echo "::group::[FAILED] PyTorch mGPU Log" - cat /workspace/torch_mgpu.log + if [ $examples_rc -ne 0 ]; then + echo "::group::[FAILED] Examples Log" + cat /workspace/examples.log echo "::endgroup::" - echo "::error::Pytorch mGPU test FAILED." - touch /workspace/FAIL_TORCH_MGPU + echo "::error::Examples FAILED." + touch /workspace/FAIL_EXAMPLES fi - if [ $jax_rc -ne 0 ]; then - echo "::group::[FAILED] JAX mGPU Log" - cat /workspace/jax_mgpu.log + if [ $core_rc -ne 0 ]; then + echo "::group::[FAILED] Core Log" + cat /workspace/core.log echo "::endgroup::" - echo "::error::JAX mGPU test FAILED." - touch /workspace/FAIL_JAX_MGPU + echo "::error::Core tests FAILED." + touch /workspace/FAIL_CORE fi - - test $torch_rc -eq 0 -a $jax_rc -eq 0 - EOF - )" - # Export failed tests statuses to host runner - if [ -f FAIL_TORCH_MGPU ]; then echo "torch=fail" >> $GITHUB_OUTPUT; fi - if [ -f FAIL_JAX_MGPU ]; then echo "jax=fail" >> $GITHUB_OUTPUT; fi - - - name: Run Examples - id: examples-tests - continue-on-error: true - env: - HF_TOKEN: ${{ secrets.HF_TOKEN }} - run: | - docker exec -e HF_TOKEN="$HF_TOKEN" te-runner bash -c "$(cat <<'EOF' - #!/usr/bin/bash - set -ex -o pipefail - ulimit -c 0 # Disable core dumps - - # Check whether the HF_TOKEN is present - python -c "import os; print('HF_TOKEN set:', bool(os.environ.get('HF_TOKEN')))" - - cd /workspace/examples/pytorch/mnist - python main.py 2>&1 | tee /workspace/examples.log - python main.py --use-te 2>&1 | tee -a /workspace/examples.log - python main.py --use-fp8 2>&1 | tee -a /workspace/examples.log - - cd /workspace/examples/jax/mnist - pip3 install -r requirements.txt - python test_single_gpu_mnist.py 2>&1 | tee -a /workspace/examples.log - python test_single_gpu_mnist.py --use-te 2>&1 | tee -a /workspace/examples.log - python test_single_gpu_mnist.py --use-fp8 2>&1 | tee -a /workspace/examples.log - - cd /workspace/examples/jax/encoder - pip3 install -r requirements.txt - python test_single_gpu_encoder.py 2>&1 | tee -a /workspace/examples.log - python test_single_gpu_encoder.py --use-fp8 2>&1 | tee -a /workspace/examples.log + test $torch_rc -eq 0 -a $jax_rc -eq 0 -a $examples_rc -eq 0 -a $core_rc -eq 0 EOF )" - - name: Check Test Failure Status + - name: Check suite failure status if: always() run: | EXIT_STATUS=0 - # Check outcomes of the specific test steps - # "outcome" will be 'failure' even if continue-on-error was true - - # sGPU CHECKS - # We check for the file existence directly because the 'Run sGPU tests' step - # halts immediately on docker failure, skipping the lines that set step outputs. - if [[ -f FAIL_CORE_SGPU ]]; then - echo "::error::Core sGPU Tests Failed." - EXIT_STATUS=1 - fi - if [[ -f FAIL_TORCH_SGPU ]]; then - echo "::error::PyTorch sGPU Tests Failed." + if [[ -f FAIL_TORCH ]]; then + echo "::error::PyTorch tests failed." EXIT_STATUS=1 fi - if [[ -f FAIL_JAX_SGPU ]]; then - echo "::error::JAX sGPU Tests Failed." + if [[ -f FAIL_JAX ]]; then + echo "::error::JAX tests failed." EXIT_STATUS=1 fi - - # mGPU CHECKS - if [[ -f FAIL_TORCH_MGPU ]]; then - echo "::error::PyTorch mGPU Tests Failed." + if [[ -f FAIL_EXAMPLES ]]; then + echo "::error::Examples failed." EXIT_STATUS=1 fi - if [[ -f FAIL_JAX_MGPU ]]; then - echo "::error::JAX mGPU Tests Failed." + if [[ -f FAIL_CORE ]]; then + echo "::error::Core tests failed." EXIT_STATUS=1 fi + exit $EXIT_STATUS - # EXAMPLES CHECK - # Examples script does not use marker files, so we rely on step outcome - if [[ "${{ steps.examples-tests.outcome }}" == "failure" ]]; then - echo "::error::Example Tests Failed." - EXIT_STATUS=1 - fi + - name: Upload logs + if: always() + uses: actions/upload-artifact@v4 + with: + name: logs-sgpu-${{ matrix.arch_label }} + path: | + *.log + if-no-files-found: ignore + retention-days: 5 + + - name: Cleanup container + if: always() + run: docker rm -f te-runner || true + + mgpu_tests: + name: mGPU ${{ matrix.framework == 'pytorch' && 'Torch' || 'JAX' }} (${{ matrix.arch_label }}) + needs: [select_image, build] + timeout-minutes: 360 + runs-on: ${{ matrix.arch_label == 'mi30x' && 'linux-te-mi30x-8' || 'linux-te-mi35x-8' }} + strategy: + fail-fast: false + matrix: + arch_label: [mi30x, mi35x] + framework: [pytorch, jax] + steps: + - name: Checkout repository + uses: actions/checkout@v6 + + - name: Download build artifacts + uses: actions/download-artifact@v4 + with: + name: te-rocm-wheels + path: dist/ + + - name: Host Diagnostics + run: | + echo "::group::Host Diagnostics" + echo ">>> GPU info:" + ls -l /dev/dri + ls -l /dev/kfd + rocm-smi + echo "::endgroup::" + + - name: Pull Docker Image + run: | + docker pull ${{ needs.select_image.outputs.image-tag }} - # Fail the job if any errors were detected - if [[ "$EXIT_STATUS" == "1" ]]; then - exit 1 + - name: Run Container + run: | + docker run -dt \ + --rm \ + --name te-runner \ + --network=host \ + --device=/dev/dri --device=/dev/kfd \ + --shm-size=16G \ + --pid=host \ + --group-add $(getent group render | cut -d: -f3) \ + --group-add $(getent group video | cut -d: -f3) \ + -v "${{ github.workspace }}:/workspace" \ + -w /workspace \ + ${{ needs.select_image.outputs.image-tag }} + + - name: Install packages + env: + FRAMEWORK: ${{ matrix.framework }} + run: | + docker exec -e FRAMEWORK="$FRAMEWORK" te-runner bash -c "$(cat <<'EOF' + set -ex + + TE_CORE_PKG=$(find /workspace/dist -type f -name 'transformer_engine_rocm[0-9]*.whl' | sort | head -n 1) + if [ "$FRAMEWORK" = "pytorch" ]; then + TE_FW_PKG=$(find /workspace/dist -type f -name 'transformer_engine_rocm_torch*.tar.gz' | sort | head -n 1) + else + TE_FW_PKG=$(find /workspace/dist -type f -name 'transformer_engine_rocm_jax*.tar.gz' | sort | head -n 1) fi + test -n "$TE_CORE_PKG" && test -n "$TE_FW_PKG" - - name: Copy logs and reports from container - if: always() + pip install --no-deps "$TE_CORE_PKG" + pip install ninja pybind11[global] + pip install --no-build-isolation --no-deps "$TE_FW_PKG" + EOF + )" + + - name: Run mGPU tests + id: mgpu-tests + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} run: | - docker cp te-runner:/workspace/torch_sgpu.log ./torch_sgpu.log || true - docker cp te-runner:/workspace/jax_sgpu.log ./jax_sgpu.log || true - docker cp te-runner:/workspace/core_sgpu.log ./core_sgpu.log || true - docker cp te-runner:/workspace/torch_mgpu.log ./torch_mgpu.log || true - docker cp te-runner:/workspace/jax_mgpu.log ./jax_mgpu.log || true + case "${{ matrix.framework }}" in + pytorch) TEST_SCRIPT=ci/pytorch.sh; LOG_FILE=/workspace/torch_mgpu.log; SUITE_NAME=PyTorch ;; + jax) TEST_SCRIPT=ci/jax.sh; LOG_FILE=/workspace/jax_mgpu.log; SUITE_NAME=JAX ;; + *) echo "::error::Unknown framework: ${{ matrix.framework }}"; exit 1 ;; + esac + + docker exec \ + -e TEST_MGPU=1 \ + -e TEST_LEVEL=${{ env.TEST_LEVEL }} \ + -e TEST_SCRIPT=$TEST_SCRIPT \ + -e LOG_FILE=$LOG_FILE \ + -e SUITE_NAME=$SUITE_NAME \ + -e NVTE_FRAMEWORK=${{ matrix.framework }} \ + -e HF_TOKEN="$HF_TOKEN" \ + te-runner bash -c "$(cat <<'EOF' + #!/usr/bin/bash + set -x -o pipefail + ulimit -c 0 # Disable core dumps + + "$TEST_SCRIPT" > "$LOG_FILE" 2>&1 + test_rc=$? + + if [ $test_rc -ne 0 ]; then + echo "::group::[FAILED] ${SUITE_NAME} mGPU Log" + cat "$LOG_FILE" + echo "::endgroup::" + echo "::error::${SUITE_NAME} mGPU tests FAILED." + fi + + exit $test_rc + EOF + )" - - name: Upload logs and test reports + - name: Upload logs if: always() uses: actions/upload-artifact@v4 with: - name: logs-and-reports-${{ matrix.runner }} + name: logs-mgpu-${{ matrix.arch_label }}-${{ matrix.framework }} path: | *.log if-no-files-found: ignore From d48aac8618525725127e37881bde34c260fe596d Mon Sep 17 00:00:00 2001 From: omkar kakarparthi <75638701+okakarpa@users.noreply.github.com> Date: Thu, 11 Jun 2026 12:26:54 -0500 Subject: [PATCH 06/14] rocm-ci: scope test container to pod-allocated GPUs (#611) * rocm-ci: scope test container to pod-allocated GPUs via podinfo The sGPU/mGPU jobs launched the test container with '--device=/dev/dri --device=/dev/kfd', exposing ALL host GPUs to the nested (privileged-dind) container regardless of the GPUs Kubernetes allocated to the pod. Combined with the hard-coded absolute HIP_VISIBLE_DEVICES=0..3, two jobs co-scheduled on the same node both pinned physical GPUs 0-3 and collided (OOM/hangs/test failures) while 4-7 sat idle. Jobs only passed when the node was otherwise idle -- arch-independent (mi300x and mi35x). Build GPU_FLAG from /etc/podinfo/gha-render-devices, which the runner populates with this pod's allocated '--device /dev/dri/renderD*' flags (falls back to all GPUs on bare metal). /dev/kfd is always passed. The container now sees only its allocated GPUs as 0..N-1, so the per-suite HIP_VISIBLE_DEVICES=0/1/2/3 split is correct and collision-free across co-scheduled pods. Requires the runner ScaleSet to populate /etc/podinfo/gha-render-devices (see companion rocOps change). --------- Co-authored-by: leo-automation (cherry picked from commit e66f431af70df4bd73d996da51842ef512fb5051) --- .github/workflows/rocm-ci.yml | 22 ++++++---------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/.github/workflows/rocm-ci.yml b/.github/workflows/rocm-ci.yml index 5e0ae242c..9cf8b90cc 100644 --- a/.github/workflows/rocm-ci.yml +++ b/.github/workflows/rocm-ci.yml @@ -142,13 +142,16 @@ jobs: run: | docker pull ${{ needs.select_image.outputs.image-tag }} - - name: Run Container + - &run_container + name: Run Container run: | + DEVICE_FLAG="$(cat /etc/podinfo/gha-render-devices 2>/dev/null || echo --device=/dev/dri)" + test -n "$DEVICE_FLAG" docker run -dt \ --rm \ --name te-runner \ --network=host \ - --device=/dev/dri --device=/dev/kfd \ + --device=/dev/kfd $DEVICE_FLAG \ --shm-size=16G \ --pid=host \ --group-add $(getent group render | cut -d: -f3) \ @@ -338,20 +341,7 @@ jobs: run: | docker pull ${{ needs.select_image.outputs.image-tag }} - - name: Run Container - run: | - docker run -dt \ - --rm \ - --name te-runner \ - --network=host \ - --device=/dev/dri --device=/dev/kfd \ - --shm-size=16G \ - --pid=host \ - --group-add $(getent group render | cut -d: -f3) \ - --group-add $(getent group video | cut -d: -f3) \ - -v "${{ github.workspace }}:/workspace" \ - -w /workspace \ - ${{ needs.select_image.outputs.image-tag }} + - *run_container - name: Install packages env: From 70248de2bbbd31f92346b8bb8807ecdf4831d633 Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Wed, 29 Apr 2026 11:49:21 -0500 Subject: [PATCH 07/14] [ROCm] add the bias all row -inf support for jax unfused-attn (#556) * [ROCm] add the bias all row -inf support for jax unfused-attn * [ROCm] address reviewer comments and fix the pytest failure * [ROCm] add the ck guard to newly added all row -inf bias test * Update tests/jax/test_fused_attn.py Co-authored-by: Meekail Zain <34613774+Micky774@users.noreply.github.com> --------- Co-authored-by: Meekail Zain <34613774+Micky774@users.noreply.github.com> (cherry picked from commit 001807b9c27c93bee4a6a7c8ce45832c6d3efee8) --- tests/jax/test_fused_attn.py | 114 +++++++++++++++++++++++++++++++++++ 1 file changed, 114 insertions(+) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 26f5514d3..68a855161 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -101,6 +101,13 @@ def general_dot_product_attention( logits = logits.reshape((b, h_kv * num_groups, s_q, s_kv)) # apply post-scale bias logits = logits + bias + # [ROCm] Detect query rows where ALL bias values are -inf (fully masked out). + # These rows would produce NaN in softmax; zero logits to prevent NaN since + # the softmax output for these rows is zeroed out below anyway. + # Use equality check against -inf so real-valued bias (e.g. 1HSS) is unaffected. + if is_hip_extension(): + bias_all_neg_mask = jnp.all(bias == -jnp.inf, axis=-1, keepdims=True) + logits = jnp.where(bias_all_neg_mask, 0, logits) # reshape logits back to original logits = logits.reshape((b, h_kv, num_groups, s_q, s_kv)) @@ -129,6 +136,15 @@ def general_dot_product_attention( case _: raise NotImplementedError(f"Unknown {softmax_type=}") + # [ROCm] Zero out softmax for fully-masked rows to prevent NaN propagation in backward + # Reuses bias_all_neg_mask from the pre-softmax block above to avoid drift. + if bias is not None and is_hip_extension(): + # Reshape 4D mask to 5D to match softmax_out (b, h_kv, num_groups, s_q, s_kv) + ms = bias_all_neg_mask.shape # (batch_or_1, h_or_1, s_q, 1) + hkv = min(ms[1], h_kv) # 1 for B1SS/11SS, h_kv for 1HSS + mask_5d = bias_all_neg_mask.reshape(ms[0], hkv, -1, ms[2], ms[3]) + softmax_out = jnp.where(mask_5d, 0, softmax_out) + if not deterministic and dropout_rate > 0.0: keep_prob = 1.0 - dropout_rate keep = jax.random.bernoulli(dropout_rng, keep_prob, softmax_out.shape) @@ -1038,6 +1054,12 @@ def grad_func(func, *args, cp_reverse_out=False, **kwargs): if self.dropout_prob > 0.0: return + # [ROCm] Verify no NaN or Inf in forward outputs + if is_hip_extension(): + for name, out in [("Fused", primitive_out), ("Reference", reference_out)]: + assert not jnp.any(jnp.isnan(out)), f"{name} attention output contains NaN" + assert not jnp.any(jnp.isinf(out)), f"{name} attention output contains Inf" + print_debug_tensor_stats(f"primitive_out", primitive_out) print_debug_tensor_stats(f"reference_grad_valid", reference_out) print_debug_tensor_stats(f"diff_grad", jnp.abs(primitive_out - reference_out)) @@ -1065,6 +1087,17 @@ def check_dqkv(primitive, reference, pad, idx): primitive_dk = self.cp_inverse_reorder_fn(primitive_dk) primitive_dv = self.cp_inverse_reorder_fn(primitive_dv) + # [ROCm] Verify no NaN or Inf in gradients + if is_hip_extension(): + for grad_name, p_grad, r_grad in [ + ("dq", primitive_dq, reference_dq), + ("dk", primitive_dk, reference_dk), + ("dv", primitive_dv, reference_dv), + ]: + for src_name, grad in [("Fused", p_grad), ("Reference", r_grad)]: + assert not jnp.any(jnp.isnan(grad)), f"{src_name} {grad_name} contains NaN" + assert not jnp.any(jnp.isinf(grad)), f"{src_name} {grad_name} contains Inf" + check_dqkv(primitive_dq, reference_dq, self.pad_q, 0) check_dqkv(primitive_dk, reference_dk, self.pad_kv, 1) check_dqkv(primitive_dv, reference_dv, self.pad_kv, 2) @@ -1487,6 +1520,87 @@ def test_backward( ) runner.test_backward() +@pytest.mark.skipif( + not is_hip_extension(), reason="Bias all-neg-inf NaN fix is ROCm-specific (SWDEV-561757)" +) +@pytest.mark.parametrize( + "s_kv, h_kv, bias_shape", + [ + pytest.param(1024, 12, BiasShape._B1SS, id="B1SS-SELF"), + pytest.param(512, 12, BiasShape._B1SS, id="B1SS-CROSS"), + pytest.param(1024, 6, BiasShape._1HSS, id="1HSS-GQA"), + pytest.param(1024, 12, BiasShape._BHSS, id="BHSS-SELF"), + pytest.param(1024, 12, BiasShape._11SS, id="11SS-SELF"), + ], +) +@pytest.mark.parametrize( + "b, s_q, h_q, d_qk, d_v, dtype, qkv_layout", + [(2, 1024, 12, 64, 64, jnp.bfloat16, QKVLayout.BSHD_BSHD_BSHD)], +) +def test_backward_bias_all_neg_inf( + b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype, qkv_layout, bias_shape +): + """ + Test backward with bias containing true -inf values for all supported bias shapes. + Regression test for SWDEV-561757: when bias rows are ALL -inf (fully masked-out query + positions), softmax produces NaN which propagates into dq/dk/dv. The CK kernel fix and + the reference fix in general_dot_product_attention handle this by zeroing out fully-masked + rows before and after softmax. + + The bias is a binary mask with only two values: 0 and -inf. Some rows are entirely -inf + (fully masked-out), while other rows have a mix of 0 and -inf (partially masked). + """ + runner = FusedAttnRunner( + batch_size=b, + max_seqlen_q=s_q, + max_seqlen_kv=s_kv, + num_heads_q=h_q, + num_heads_kv=h_kv, + head_dim_qk=d_qk, + head_dim_v=d_v, + attn_bias_type=AttnBiasType.POST_SCALE_BIAS, + attn_mask_type=AttnMaskType.NO_MASK, + softmax_type=AttnSoftmaxType.VANILLA_SOFTMAX, + dropout_prob=0.0, + use_old_rng=True, + dtype=dtype, + is_training=True, + qkv_layout=qkv_layout, + bias_shape=bias_shape, + window_size=None, + seq_desc_format=SeqDescFormat.Mask, + ) + runner._setup_inputs() + + if runner.backend != NVTE_Fused_Attn_Backend.NVTE_CK: + pytest.skip("All-neg-inf bias NaN fix is CK-specific") + + # Build a binary bias with only two values: 0 and -inf. + # Shape depends on bias_shape: B1SS=(b,1,s_q,s_kv), 1HSS=(1,h_q,s_q,s_kv), etc. + shape_map = { + BiasShape._B1SS: (b, 1, s_q, s_kv), + BiasShape._1HSS: (1, h_q, s_q, s_kv), + BiasShape._BHSS: (b, h_q, s_q, s_kv), + BiasShape._11SS: (1, 1, s_q, s_kv), + } + concrete_shape = shape_map[bias_shape] + bias = jnp.full(concrete_shape, -jnp.inf, dtype=dtype) + # Use an explicit block-diagonal pattern with guaranteed gaps so that: + # - Rows inside a block have a mix of 0 (within-block columns) and -inf (outside) + # - Rows in the gaps between blocks are entirely -inf (fully masked-out) + block_size = min(s_q, s_kv) // 8 + gap_size = block_size // 2 + pos = 0 + while pos + block_size <= min(s_q, s_kv): + bias = bias.at[:, :, pos : pos + block_size, pos : pos + block_size].set(0.0) + pos += block_size + gap_size # leave a gap of all-neg-inf rows + runner.bias = bias + + # Prevent test_backward from re-running _setup_inputs which would regenerate bias + runner._setup_inputs = lambda: None + runner.test_backward() + + # Single test with new-style RNG @pytest.mark.skipif( not is_hip_extension(), reason="New-style RNGs only enabled on AMD hardware" From ac126171b61c15afe3f174627331b4832c3f3708 Mon Sep 17 00:00:00 2001 From: alextmagro Date: Wed, 29 Apr 2026 11:51:09 -0500 Subject: [PATCH 08/14] Disable all UB layer tests for gfx942 (#567) (cherry picked from commit e0587a92c4aec3be01d1168611bec30f628506c1) --- tests/pytorch/distributed/test_comm_gemm_overlap.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/pytorch/distributed/test_comm_gemm_overlap.py b/tests/pytorch/distributed/test_comm_gemm_overlap.py index 8d98c6263..68671447b 100644 --- a/tests/pytorch/distributed/test_comm_gemm_overlap.py +++ b/tests/pytorch/distributed/test_comm_gemm_overlap.py @@ -105,9 +105,9 @@ def _run_layer_with_overlap( pytest.skip("Bulk overlap is not yet supported on HIP/ROCm.") # On gfx942, non-determinism across the 8 XCDs causes small jitter that compounds # This should not affect training convergence, but creates larger numerical differences. + # TODO: Fix gfx942 issues arising from deterministic bwd attention and other jitter if (IS_HIP_EXTENSION - and get_device_compute_capability() < (9, 5) - and layer_type == te.TransformerLayer.__name__): + and get_device_compute_capability() < (9, 5)): pytest.skip("TransformerLayer overlap can exceed numerical tolerance on pre-MI350 due to jitter.") test_path = TEST_ROOT / "run_layer_with_overlap.py" test_cmd = LAUNCH_CMD + [ From 67f6e744a8392049771377fdb0e7d03e9a9b5b50 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 4 May 2026 17:10:47 -0500 Subject: [PATCH 09/14] upgrade hypothesis/setuptools (#572) (cherry picked from commit ed839f4861fc06a2d43cb89f97d492ab08a661a5) --- .github/workflows/rocm-ci.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/rocm-ci.yml b/.github/workflows/rocm-ci.yml index 9cf8b90cc..f148c2471 100644 --- a/.github/workflows/rocm-ci.yml +++ b/.github/workflows/rocm-ci.yml @@ -174,6 +174,7 @@ jobs: pip install --no-deps "$TE_CORE_PKG" pip install ninja pybind11[global] + pip install --upgrade hypothesis setuptools pip install --no-build-isolation --no-deps "$TE_TORCH_PKG" pip install --no-build-isolation --no-deps "$TE_JAX_PKG" EOF @@ -360,6 +361,7 @@ jobs: pip install --no-deps "$TE_CORE_PKG" pip install ninja pybind11[global] + pip install --upgrade hypothesis setuptools pip install --no-build-isolation --no-deps "$TE_FW_PKG" EOF )" From 43d5fcfc9a50c9543051208d6e14307a48162eb6 Mon Sep 17 00:00:00 2001 From: ipanfilo <145064111+ipanfilo@users.noreply.github.com> Date: Fri, 29 May 2026 23:30:41 -0400 Subject: [PATCH 10/14] Install: use setuptools bdist_wheel; CI: call pytest as module (#601) (cherry picked from commit a47087b4e78804a551d0b777947eb6476b1a2f69) --- ci/_utils.sh | 5 +++-- setup.py | 5 ++++- transformer_engine/pytorch/setup.py | 5 ++++- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/ci/_utils.sh b/ci/_utils.sh index b4aae9cc7..9c0f4a847 100644 --- a/ci/_utils.sh +++ b/ci/_utils.sh @@ -234,7 +234,7 @@ start_message() { fi _rocm_path=`$REALPATH "$_rocm_path"` test -d "$_rocm_path" && echo "ROCm: $_rocm_path" || echo "ROCm path not found" - python --version + python3 --version } configure_omp_threads() { @@ -266,6 +266,7 @@ pytest_run() { check_test_filter $_test_name_tag || return _start_ts=`date +%s` echo "Run [$_test_variant_tag] $@ at `time_elapsed $TEST_START_TS`" - pytest -v -rfEs `get_pytest_junitxml $_test_name_tag` $TEST_PYTEST_ARGS "$TEST_DIR/$@" || test_run_error "[$_test_variant_tag] $1" + python3 -m pytest -v -rfEs `get_pytest_junitxml $_test_name_tag` $TEST_PYTEST_ARGS "$TEST_DIR/$@" + test $? -eq 0 || test_run_error "[$_test_variant_tag] $1" echo "Done [$_test_variant_tag] $1 in `time_elapsed $_start_ts`" } diff --git a/setup.py b/setup.py index c66af39df..20ec983d3 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,10 @@ import setuptools from setuptools.command.egg_info import egg_info -from wheel.bdist_wheel import bdist_wheel +try: + from setuptools.command.bdist_wheel import bdist_wheel +except ImportError: + from wheel.bdist_wheel import bdist_wheel from build_tools.build_ext import CMakeExtension, get_build_ext from build_tools.te_version import te_version diff --git a/transformer_engine/pytorch/setup.py b/transformer_engine/pytorch/setup.py index a80dcacb8..616a50f44 100644 --- a/transformer_engine/pytorch/setup.py +++ b/transformer_engine/pytorch/setup.py @@ -15,7 +15,10 @@ import platform import urllib import setuptools -from wheel.bdist_wheel import bdist_wheel as _bdist_wheel +try: + from setuptools.command.bdist_wheel import bdist_wheel as _bdist_wheel +except ImportError: + from wheel.bdist_wheel import bdist_wheel as _bdist_wheel from packaging.version import parse try: From b4b46081b0b55ec7a30992f33e9d0d2a84163b48 Mon Sep 17 00:00:00 2001 From: ipanfilo <145064111+ipanfilo@users.noreply.github.com> Date: Tue, 9 Jun 2026 11:23:39 -0400 Subject: [PATCH 11/14] JAX 0.9 compatibility changes (#604) * Do not use deprecated pxla.thread_resources Meshs on JAX 0.9 * Fix typo in FFI target registration (cherry picked from commit 5a5f7dad0cf0aca355a8e68c1c490d0ad387f6f4) --- .../attention/benchmark_attention_jax.py | 4 +- tests/jax/test_distributed_dense.py | 6 +- tests/jax/test_distributed_layernorm.py | 6 +- tests/jax/test_distributed_layernorm_mlp.py | 4 +- tests/jax/test_distributed_softmax.py | 4 +- tests/jax/test_fused_attn.py | 8 +-- transformer_engine/jax/sharding.py | 62 ++++++++++++++++--- 7 files changed, 74 insertions(+), 20 deletions(-) diff --git a/benchmarks/attention/benchmark_attention_jax.py b/benchmarks/attention/benchmark_attention_jax.py index 54dd28505..b0b60cdb8 100644 --- a/benchmarks/attention/benchmark_attention_jax.py +++ b/benchmarks/attention/benchmark_attention_jax.py @@ -136,7 +136,7 @@ def bench_forward(self, warmup, iters, timings_dir): self.dropout_rng_sharding, ], ) - with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource): + with jax.set_mesh(self.mesh), fp8_autocast(mesh_resource=self.mesh_resource): for _ in range(warmup): customcall_fused_dpa_jit(*customcall_args) @@ -227,7 +227,7 @@ def grad_func(func, *args, cp_reverse_out=False, **kwargs): ), out_shardings=(None, grad_shardings), ) - with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource): + with jax.set_mesh(self.mesh), fp8_autocast(mesh_resource=self.mesh_resource): for _ in range(warmup): jitted_primitive(*customcall_args) diff --git a/tests/jax/test_distributed_dense.py b/tests/jax/test_distributed_dense.py index b8caf188d..818298ed8 100644 --- a/tests/jax/test_distributed_dense.py +++ b/tests/jax/test_distributed_dense.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -127,7 +129,7 @@ def test_distributed_gemm( contracting_dims = ((2,), (0,)) # Contract on hidden_in dimension - with mesh, autocast(enabled=False, mesh_resource=mesh_resource): + with jax.set_mesh(mesh), autocast(enabled=False, mesh_resource=mesh_resource): # TE GEMM result te_result = _jitted_gemm( x_sharded, @@ -209,7 +211,7 @@ def test_te_distributed_dense_grad( contracting_dims = ((2,), (0,)) - with mesh, autocast(enabled=False, mesh_resource=mesh_resource): + with jax.set_mesh(mesh), autocast(enabled=False, mesh_resource=mesh_resource): # Test gradients w.r.t. all inputs te_grad_func = jax.jit( jax.value_and_grad(self._te_sum_dense, argnums=(0, 1, 2)), diff --git a/tests/jax/test_distributed_layernorm.py b/tests/jax/test_distributed_layernorm.py index 21359cedf..eb4497a0a 100644 --- a/tests/jax/test_distributed_layernorm.py +++ b/tests/jax/test_distributed_layernorm.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -135,7 +137,7 @@ def ref_func(x, gamma, beta): ) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) - with mesh, autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource): + with jax.set_mesh(mesh), autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource): x_named_sharding = NamedSharding(mesh, x_pspec) g_named_sharding = NamedSharding(mesh, g_pspec) b_named_sharding = NamedSharding(mesh, b_pspec) @@ -217,7 +219,7 @@ def ref_func(x, gamma): ) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) - with mesh, autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource): + with jax.set_mesh(mesh), autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource): x_named_sharding = NamedSharding(mesh, x_pspec) g_named_sharding = NamedSharding(mesh, g_pspec) x_ = jax.device_put(x, x_named_sharding) diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index 6a2f395b1..4ed9e3cf5 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -261,7 +261,7 @@ def _test_layernorm_mlp_grad( # Multi GPUs devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) - with mesh, autocast( + with jax.set_mesh(mesh), autocast( enabled=quantization_recipe is not None, recipe=quantization_recipe, mesh_resource=mesh_resource, @@ -452,7 +452,7 @@ def _test_layernorm_mlp( device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) - with mesh, autocast( + with jax.set_mesh(mesh), autocast( enabled=use_fp8, recipe=quantization_recipe, mesh_resource=mesh_resource ): ln_mlp_sharded = LayerNormMLP( diff --git a/tests/jax/test_distributed_softmax.py b/tests/jax/test_distributed_softmax.py index 0665baa4e..ff44f249c 100644 --- a/tests/jax/test_distributed_softmax.py +++ b/tests/jax/test_distributed_softmax.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -109,7 +111,7 @@ def impl_test_softmax( collective_count_ref = self.generate_collectives_count_ref() devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) - with mesh, autocast(mesh_resource=mesh_resource): + with jax.set_mesh(mesh), autocast(mesh_resource=mesh_resource): x_named_sharding = NamedSharding(mesh, x_pspec) mask_named_sharding = NamedSharding(mesh, mask_pspec) x_ = jax.device_put(x, x_named_sharding) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 68a855161..2b69a397b 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -914,7 +914,7 @@ def test_forward(self): ], ) - with self.mesh, autocast(mesh_resource=self.mesh_resource): + with jax.set_mesh(self.mesh), autocast(mesh_resource=self.mesh_resource): primitive_out = customcall_fused_dpa_jit(*customcall_args) primitive_out = self.cp_inverse_reorder_fn(primitive_out) @@ -931,7 +931,7 @@ def test_forward(self): assert_allclose(primitive_valid, reference_valid, dtype=self.dtype) if self.coll_count_ref is not None: - with self.mesh, autocast(mesh_resource=self.mesh_resource): + with jax.set_mesh(self.mesh), autocast(mesh_resource=self.mesh_resource): target_hlo = ( customcall_fused_dpa_jit.lower(*customcall_args, **kwargs).compile().as_text() ) @@ -1045,7 +1045,7 @@ def grad_func(func, *args, cp_reverse_out=False, **kwargs): ) ) - with self.mesh, autocast(mesh_resource=self.mesh_resource): + with jax.set_mesh(self.mesh), autocast(mesh_resource=self.mesh_resource): primitive_out, primitive_dgrad = jitted_primitive(*customcall_args) reference_out, reference_dgrad = jitted_reference(*args) @@ -1133,7 +1133,7 @@ def check_dqkv(primitive, reference, pad, idx): ) if self.coll_count_ref is not None: - with self.mesh, autocast(mesh_resource=self.mesh_resource): + with jax.set_mesh(self.mesh), autocast(mesh_resource=self.mesh_resource): target_hlo = jitted_primitive.lower(*customcall_args).compile().as_text() assert_equal_collectives(target_hlo, self.coll_count_ref) diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index 9b13412c1..a3df47f11 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -11,6 +13,7 @@ """ from contextlib import contextmanager from dataclasses import dataclass +from packaging import version from typing import Callable, Optional import warnings @@ -20,7 +23,8 @@ from jax.sharding import PartitionSpec, get_abstract_mesh import numpy as np -_PXLA_THREAD_RESOURCES = pxla.thread_resources +if version.parse(jax.__version__) < version.parse("0.9.0"): + _PXLA_THREAD_RESOURCES = pxla.thread_resources # Axis Names BATCH_AXES = "nvte_batch" @@ -39,9 +43,11 @@ def _get_mesh(): # Handle Mesh's set via `with mesh:` - mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh - if mesh is not None and not mesh.empty: - return mesh + # ROCm: add JAX version guard for all backends + if version.parse(jax.__version__) < version.parse("0.9.0"): + mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh + if mesh is not None and not mesh.empty: + return mesh # Handle Mesh's set via `jax.set_mesh(mesh)` return jax.sharding.get_abstract_mesh() @@ -164,6 +170,31 @@ def filter_manual_axes(name_or_tuple): return x cleaned_pspec = PartitionSpec(*cleaned_axis_names) + + # ROCm: JAX 0.9 compat (all backends) — when an AbstractMesh is active, + # jax.lax.with_sharding_constraint requires the input to already carry a + # NamedSharding. This affects both concrete arrays in eager mode and traced + # values inside jax.jit whose abstract sharding is not a NamedSharding (e.g. + # Module.init() traces over a single-device input and JAX propagates the + # SingleDeviceSharding through the Tracer). In both cases the constraint must + # be skipped because JAX raises unconditionally. + # A UserWarning is emitted only for concrete (non-Tracer) arrays so the user + # gets a visible signal in eager mode; the jit-traced skip is unavoidable and + # kept silent to avoid spurious warnings from traced code. + if hasattr(x, "sharding") and not isinstance(x.sharding, jax.sharding.NamedSharding): + if not isinstance(x, jax.core.Tracer): + warnings.warn( + f"with_sharding_constraint: the sharding constraint {cleaned_pspec!r} was not" + f" applied because the input array carries a {type(x.sharding).__name__} rather" + " than a NamedSharding. This typically happens in eager mode when arrays have not" + " yet been placed on a mesh (e.g. during model initialisation). Wrap the call in" + " jax.jit or ensure the array is on a named mesh before applying sharding" + " constraints.", + UserWarning, + stacklevel=2, + ) + return x + return jax.lax.with_sharding_constraint(x, cleaned_pspec) @@ -359,6 +390,14 @@ def global_shard_guard(resource: MeshResource): old_resources = _GLOBAL_MESH_RESOURCE try: _GLOBAL_MESH_RESOURCE = resource + # ROCm: JAX 0.9 compat (all backends) + # Validate once at context-setup time, where get_abstract_mesh() correctly + # reflects the physical mesh. Calling _validate_mesh_resource_configuration + # from global_mesh_resource() (i.e. on every access) breaks in JAX 0.9 + # because get_abstract_mesh() returns an empty AbstractMesh when called + # from inside a custom_partitioning sharded_impl during jit(...).lower(). + if resource is not None: + _validate_mesh_resource_configuration(resource) yield finally: _GLOBAL_MESH_RESOURCE = old_resources @@ -375,7 +414,13 @@ def global_mesh_resource() -> MeshResource: " context. If you are not using multiple GPUs, you can use an empty MeshResource by" " wrapping your program in 'with global_shard_guard(MeshResource()):'" ) - _validate_mesh_resource_configuration(_GLOBAL_MESH_RESOURCE) + # ROCm: JAX 0.9 compat (all backends) + # _validate_mesh_resource_configuration is intentionally NOT called here. + # Validation is done once in global_shard_guard() at context-setup time, where + # get_abstract_mesh() correctly reflects the physical mesh. Calling it here + # would break in JAX 0.9 when global_mesh_resource() is invoked from inside a + # custom_partitioning sharded_impl during jit(...).lower(), at which point + # get_abstract_mesh() returns an empty AbstractMesh. return _GLOBAL_MESH_RESOURCE @@ -418,8 +463,11 @@ def all_reduce_max_along_all_axes_except_PP(x: jnp.array, mesh: jax.sharding.Mes Returns: Reduced tensor """ - all_axes = get_all_mesh_axes() - for axis in all_axes: + # ROCm: JAX 0.9 compat (all backends) + # Use mesh.axis_names from the concrete mesh argument rather than calling + # get_all_mesh_axes() → _get_mesh() → get_abstract_mesh(), which returns + # empty in JAX 0.9 when called from inside a custom_partitioning sharded_impl. + for axis in mesh.axis_names: if axis != global_mesh_resource().pp_resource: x = lax_paral_op(x, jax.lax.pmax, axis, mesh) return x From 306c8ace7495f72d68db5d690036a3177dc7ef51 Mon Sep 17 00:00:00 2001 From: Ilya Panfilov Date: Tue, 16 Jun 2026 11:31:38 -0400 Subject: [PATCH 12/14] Hotfix for Maxtext regression with JAX 0.9 changes (PR#629) (cherry picked from commit 5e7bf042e45d50487a759c3cc7c69a048c85dc6f) --- transformer_engine/jax/sharding.py | 49 ++++++++++++++++++++++-------- 1 file changed, 36 insertions(+), 13 deletions(-) diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index a3df47f11..5691280aa 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -374,6 +374,10 @@ class MeshResource: _GLOBAL_MESH_RESOURCE = None +# ROCm: True once _validate_mesh_resource_configuration has successfully run for the +# current _GLOBAL_MESH_RESOURCE. Reset to False on every global_shard_guard +# entry so that a new resource is always (re-)validated on first use. +_GLOBAL_MESH_RESOURCE_VALIDATED = False @contextmanager @@ -386,21 +390,29 @@ def global_shard_guard(resource: MeshResource): Args: resource: MeshResource instance defining the sharding configuration """ - global _GLOBAL_MESH_RESOURCE + global _GLOBAL_MESH_RESOURCE, _GLOBAL_MESH_RESOURCE_VALIDATED old_resources = _GLOBAL_MESH_RESOURCE + old_validated = _GLOBAL_MESH_RESOURCE_VALIDATED try: _GLOBAL_MESH_RESOURCE = resource # ROCm: JAX 0.9 compat (all backends) - # Validate once at context-setup time, where get_abstract_mesh() correctly - # reflects the physical mesh. Calling _validate_mesh_resource_configuration - # from global_mesh_resource() (i.e. on every access) breaks in JAX 0.9 - # because get_abstract_mesh() returns an empty AbstractMesh when called - # from inside a custom_partitioning sharded_impl during jit(...).lower(). - if resource is not None: + # Attempt early (eager) validation if a mesh is already active at + # guard-entry time. Guard with is_mesh_available() so that callers + # who enter global_shard_guard before any JAX mesh context is active + # (e.g. maxtext's transformer_engine_context) do not hit an + # AssertionError in get_mesh_axis_size() when get_abstract_mesh() + # returns an empty OrderedDict(). + # Reset the validated flag for the new resource so that + # global_mesh_resource() re-validates on its first call with an + # active mesh (lazy validation path, see below). + _GLOBAL_MESH_RESOURCE_VALIDATED = False + if resource is not None and is_mesh_available(): _validate_mesh_resource_configuration(resource) + _GLOBAL_MESH_RESOURCE_VALIDATED = True yield finally: _GLOBAL_MESH_RESOURCE = old_resources + _GLOBAL_MESH_RESOURCE_VALIDATED = old_validated def global_mesh_resource() -> MeshResource: @@ -409,18 +421,29 @@ def global_mesh_resource() -> MeshResource: Returns: The current MeshResource instance """ + global _GLOBAL_MESH_RESOURCE_VALIDATED assert _GLOBAL_MESH_RESOURCE is not None, ( "Global mesh resource is not set. Please set the MeshResource via a global_shard_guard" " context. If you are not using multiple GPUs, you can use an empty MeshResource by" " wrapping your program in 'with global_shard_guard(MeshResource()):'" ) # ROCm: JAX 0.9 compat (all backends) - # _validate_mesh_resource_configuration is intentionally NOT called here. - # Validation is done once in global_shard_guard() at context-setup time, where - # get_abstract_mesh() correctly reflects the physical mesh. Calling it here - # would break in JAX 0.9 when global_mesh_resource() is invoked from inside a - # custom_partitioning sharded_impl during jit(...).lower(), at which point - # get_abstract_mesh() returns an empty AbstractMesh. + # Lazy validation: if the mesh was not yet active when global_shard_guard + # was entered (eager validation skipped), validate here on the first call + # that actually finds an active mesh. This covers frameworks like maxtext + # that set up global_shard_guard before activating the JAX mesh context. + # + # The _GLOBAL_MESH_RESOURCE_VALIDATED flag ensures validation runs at most + # once per global_shard_guard context (reset to False on guard entry, + # set to True after successful validation): + # • After validation: `not _GLOBAL_MESH_RESOURCE_VALIDATED` is False → + # only one boolean check per call, faster than the pre-JAX-0.9-compat + # code that ran get_mesh_axis_size() unconditionally on every call. + # • Inside jit(...).lower(): is_mesh_available() returns False (JAX 0.9 + # get_abstract_mesh() is empty there) → validation safely skipped. + if not _GLOBAL_MESH_RESOURCE_VALIDATED and is_mesh_available(): + _validate_mesh_resource_configuration(_GLOBAL_MESH_RESOURCE) + _GLOBAL_MESH_RESOURCE_VALIDATED = True return _GLOBAL_MESH_RESOURCE From 46f17bac9e7d1bcebf4f0a40b6ec99ed6b2ac69e Mon Sep 17 00:00:00 2001 From: Ilya Panfilov Date: Wed, 17 Jun 2026 04:09:37 -0400 Subject: [PATCH 13/14] Add test (PR#629) (cherry picked from commit ef79328b3d6ff0f9eb7181fb83ede8d835dff0f5) --- ci/jax.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ci/jax.sh b/ci/jax.sh index 4804ecff3..bfd2ebde3 100755 --- a/ci/jax.sh +++ b/ci/jax.sh @@ -81,9 +81,10 @@ run_test_config_mgpu() { export NVTE_JAX_UNITTEST_LEVEL=L2 fi - run_default_fa 2 test_distributed_dense.py + run_default_fa 1 test_distributed_dense.py # RCCL_MSCCL_ENABLE=0 is to avoid hangs in some distributed tests (ROCM-1719) RCCL_MSCCL_ENABLE=0 run $_dfa_level test_distributed_fused_attn.py $_timeout_args + run_default_fa 1 test_distributed_helper.py run_default_fa 3 test_distributed_layernorm.py run_default_fa 2 test_distributed_layernorm_mlp.py $_timeout_args run_default_fa 3 test_distributed_softmax.py From 4ad4650c99e85c6fda2cd53bc9b9b79b97e6b896 Mon Sep 17 00:00:00 2001 From: Ilya Panfilov Date: Wed, 17 Jun 2026 13:41:50 -0400 Subject: [PATCH 14/14] Remove wrong submodule reference --- .github/workflows/rocm-wheels-build.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/rocm-wheels-build.yml b/.github/workflows/rocm-wheels-build.yml index c1a8ea087..81ff36798 100644 --- a/.github/workflows/rocm-wheels-build.yml +++ b/.github/workflows/rocm-wheels-build.yml @@ -86,7 +86,6 @@ jobs: git submodule update --init --recursive --depth 1 \ 3rdparty/aotriton \ 3rdparty/aiter \ - 3rdparty/QoLA \ 3rdparty/hipify_torch - name: Derive Docker image tag