From 5e7bf042e45d50487a759c3cc7c69a048c85dc6f Mon Sep 17 00:00:00 2001 From: Ilya Panfilov Date: Tue, 16 Jun 2026 11:31:38 -0400 Subject: [PATCH 1/2] Hotfix for Maxtext regression with JAX 0.9 changes --- transformer_engine/jax/sharding.py | 49 ++++++++++++++++++++++-------- 1 file changed, 36 insertions(+), 13 deletions(-) diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index a3df47f11..5691280aa 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -374,6 +374,10 @@ class MeshResource: _GLOBAL_MESH_RESOURCE = None +# ROCm: True once _validate_mesh_resource_configuration has successfully run for the +# current _GLOBAL_MESH_RESOURCE. Reset to False on every global_shard_guard +# entry so that a new resource is always (re-)validated on first use. +_GLOBAL_MESH_RESOURCE_VALIDATED = False @contextmanager @@ -386,21 +390,29 @@ def global_shard_guard(resource: MeshResource): Args: resource: MeshResource instance defining the sharding configuration """ - global _GLOBAL_MESH_RESOURCE + global _GLOBAL_MESH_RESOURCE, _GLOBAL_MESH_RESOURCE_VALIDATED old_resources = _GLOBAL_MESH_RESOURCE + old_validated = _GLOBAL_MESH_RESOURCE_VALIDATED try: _GLOBAL_MESH_RESOURCE = resource # ROCm: JAX 0.9 compat (all backends) - # Validate once at context-setup time, where get_abstract_mesh() correctly - # reflects the physical mesh. Calling _validate_mesh_resource_configuration - # from global_mesh_resource() (i.e. on every access) breaks in JAX 0.9 - # because get_abstract_mesh() returns an empty AbstractMesh when called - # from inside a custom_partitioning sharded_impl during jit(...).lower(). - if resource is not None: + # Attempt early (eager) validation if a mesh is already active at + # guard-entry time. Guard with is_mesh_available() so that callers + # who enter global_shard_guard before any JAX mesh context is active + # (e.g. maxtext's transformer_engine_context) do not hit an + # AssertionError in get_mesh_axis_size() when get_abstract_mesh() + # returns an empty OrderedDict(). + # Reset the validated flag for the new resource so that + # global_mesh_resource() re-validates on its first call with an + # active mesh (lazy validation path, see below). + _GLOBAL_MESH_RESOURCE_VALIDATED = False + if resource is not None and is_mesh_available(): _validate_mesh_resource_configuration(resource) + _GLOBAL_MESH_RESOURCE_VALIDATED = True yield finally: _GLOBAL_MESH_RESOURCE = old_resources + _GLOBAL_MESH_RESOURCE_VALIDATED = old_validated def global_mesh_resource() -> MeshResource: @@ -409,18 +421,29 @@ def global_mesh_resource() -> MeshResource: Returns: The current MeshResource instance """ + global _GLOBAL_MESH_RESOURCE_VALIDATED assert _GLOBAL_MESH_RESOURCE is not None, ( "Global mesh resource is not set. Please set the MeshResource via a global_shard_guard" " context. If you are not using multiple GPUs, you can use an empty MeshResource by" " wrapping your program in 'with global_shard_guard(MeshResource()):'" ) # ROCm: JAX 0.9 compat (all backends) - # _validate_mesh_resource_configuration is intentionally NOT called here. - # Validation is done once in global_shard_guard() at context-setup time, where - # get_abstract_mesh() correctly reflects the physical mesh. Calling it here - # would break in JAX 0.9 when global_mesh_resource() is invoked from inside a - # custom_partitioning sharded_impl during jit(...).lower(), at which point - # get_abstract_mesh() returns an empty AbstractMesh. + # Lazy validation: if the mesh was not yet active when global_shard_guard + # was entered (eager validation skipped), validate here on the first call + # that actually finds an active mesh. This covers frameworks like maxtext + # that set up global_shard_guard before activating the JAX mesh context. + # + # The _GLOBAL_MESH_RESOURCE_VALIDATED flag ensures validation runs at most + # once per global_shard_guard context (reset to False on guard entry, + # set to True after successful validation): + # • After validation: `not _GLOBAL_MESH_RESOURCE_VALIDATED` is False → + # only one boolean check per call, faster than the pre-JAX-0.9-compat + # code that ran get_mesh_axis_size() unconditionally on every call. + # • Inside jit(...).lower(): is_mesh_available() returns False (JAX 0.9 + # get_abstract_mesh() is empty there) → validation safely skipped. + if not _GLOBAL_MESH_RESOURCE_VALIDATED and is_mesh_available(): + _validate_mesh_resource_configuration(_GLOBAL_MESH_RESOURCE) + _GLOBAL_MESH_RESOURCE_VALIDATED = True return _GLOBAL_MESH_RESOURCE From ef79328b3d6ff0f9eb7181fb83ede8d835dff0f5 Mon Sep 17 00:00:00 2001 From: Ilya Panfilov Date: Wed, 17 Jun 2026 04:09:37 -0400 Subject: [PATCH 2/2] Add test --- ci/jax.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ci/jax.sh b/ci/jax.sh index 574376317..ab883e9e9 100755 --- a/ci/jax.sh +++ b/ci/jax.sh @@ -82,9 +82,10 @@ run_test_config_mgpu() { export NVTE_JAX_UNITTEST_LEVEL=L2 fi - run_default_fa 2 test_distributed_dense.py + run_default_fa 1 test_distributed_dense.py # RCCL_MSCCL_ENABLE=0 is to avoid hangs in some distributed tests (ROCM-1719) RCCL_MSCCL_ENABLE=0 run $_dfa_level test_distributed_fused_attn.py $_timeout_args + run_default_fa 1 test_distributed_helper.py run_default_fa 3 test_distributed_layernorm.py run_default_fa 2 test_distributed_layernorm_mlp.py $_timeout_args run_default_fa 3 test_distributed_softmax.py