Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion ci/jax.sh
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,10 @@ run_test_config_mgpu() {
export NVTE_JAX_UNITTEST_LEVEL=L2
fi

run_default_fa 2 test_distributed_dense.py
run_default_fa 1 test_distributed_dense.py
# RCCL_MSCCL_ENABLE=0 is to avoid hangs in some distributed tests (ROCM-1719)
RCCL_MSCCL_ENABLE=0 run $_dfa_level test_distributed_fused_attn.py $_timeout_args
run_default_fa 1 test_distributed_helper.py
run_default_fa 3 test_distributed_layernorm.py
run_default_fa 2 test_distributed_layernorm_mlp.py $_timeout_args
run_default_fa 3 test_distributed_softmax.py
Expand Down
49 changes: 36 additions & 13 deletions transformer_engine/jax/sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,10 @@ class MeshResource:


_GLOBAL_MESH_RESOURCE = None
# ROCm: True once _validate_mesh_resource_configuration has successfully run for the
# current _GLOBAL_MESH_RESOURCE. Reset to False on every global_shard_guard
# entry so that a new resource is always (re-)validated on first use.
_GLOBAL_MESH_RESOURCE_VALIDATED = False


@contextmanager
Expand All @@ -386,21 +390,29 @@ def global_shard_guard(resource: MeshResource):
Args:
resource: MeshResource instance defining the sharding configuration
"""
global _GLOBAL_MESH_RESOURCE
global _GLOBAL_MESH_RESOURCE, _GLOBAL_MESH_RESOURCE_VALIDATED
old_resources = _GLOBAL_MESH_RESOURCE
old_validated = _GLOBAL_MESH_RESOURCE_VALIDATED
try:
_GLOBAL_MESH_RESOURCE = resource
# ROCm: JAX 0.9 compat (all backends)
# Validate once at context-setup time, where get_abstract_mesh() correctly
# reflects the physical mesh. Calling _validate_mesh_resource_configuration
# from global_mesh_resource() (i.e. on every access) breaks in JAX 0.9
# because get_abstract_mesh() returns an empty AbstractMesh when called
# from inside a custom_partitioning sharded_impl during jit(...).lower().
if resource is not None:
# Attempt early (eager) validation if a mesh is already active at
# guard-entry time. Guard with is_mesh_available() so that callers
# who enter global_shard_guard before any JAX mesh context is active
# (e.g. maxtext's transformer_engine_context) do not hit an
# AssertionError in get_mesh_axis_size() when get_abstract_mesh()
# returns an empty OrderedDict().
# Reset the validated flag for the new resource so that
# global_mesh_resource() re-validates on its first call with an
# active mesh (lazy validation path, see below).
_GLOBAL_MESH_RESOURCE_VALIDATED = False
if resource is not None and is_mesh_available():
_validate_mesh_resource_configuration(resource)
_GLOBAL_MESH_RESOURCE_VALIDATED = True
yield
finally:
_GLOBAL_MESH_RESOURCE = old_resources
_GLOBAL_MESH_RESOURCE_VALIDATED = old_validated


def global_mesh_resource() -> MeshResource:
Expand All @@ -409,18 +421,29 @@ def global_mesh_resource() -> MeshResource:
Returns:
The current MeshResource instance
"""
global _GLOBAL_MESH_RESOURCE_VALIDATED
assert _GLOBAL_MESH_RESOURCE is not None, (
"Global mesh resource is not set. Please set the MeshResource via a global_shard_guard"
" context. If you are not using multiple GPUs, you can use an empty MeshResource by"
" wrapping your program in 'with global_shard_guard(MeshResource()):'"
)
# ROCm: JAX 0.9 compat (all backends)
# _validate_mesh_resource_configuration is intentionally NOT called here.
# Validation is done once in global_shard_guard() at context-setup time, where
# get_abstract_mesh() correctly reflects the physical mesh. Calling it here
# would break in JAX 0.9 when global_mesh_resource() is invoked from inside a
# custom_partitioning sharded_impl during jit(...).lower(), at which point
# get_abstract_mesh() returns an empty AbstractMesh.
# Lazy validation: if the mesh was not yet active when global_shard_guard
# was entered (eager validation skipped), validate here on the first call
# that actually finds an active mesh. This covers frameworks like maxtext
# that set up global_shard_guard before activating the JAX mesh context.
#
# The _GLOBAL_MESH_RESOURCE_VALIDATED flag ensures validation runs at most
# once per global_shard_guard context (reset to False on guard entry,
# set to True after successful validation):
# • After validation: `not _GLOBAL_MESH_RESOURCE_VALIDATED` is False →
# only one boolean check per call, faster than the pre-JAX-0.9-compat
# code that ran get_mesh_axis_size() unconditionally on every call.
# • Inside jit(...).lower(): is_mesh_available() returns False (JAX 0.9
# get_abstract_mesh() is empty there) → validation safely skipped.
if not _GLOBAL_MESH_RESOURCE_VALIDATED and is_mesh_available():
_validate_mesh_resource_configuration(_GLOBAL_MESH_RESOURCE)
_GLOBAL_MESH_RESOURCE_VALIDATED = True
Comment thread
ipanfilo marked this conversation as resolved.
return _GLOBAL_MESH_RESOURCE


Expand Down
Loading