diff --git a/tests/jax/test_distributed_moe_block.py b/tests/jax/test_distributed_moe_block.py new file mode 100644 index 0000000000..1c7b99cda4 --- /dev/null +++ b/tests/jax/test_distributed_moe_block.py @@ -0,0 +1,170 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Distributed tests for ``transformer_engine.jax.flax.MoEBlock``.""" + +import sys + +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np +import pytest +from jax.sharding import Mesh, PartitionSpec + +from utils import assert_allclose, is_devices_enough + + +@pytest.fixture(autouse=True, scope="function") +def _inject_moe(request): + """Lazy-load ``MoEBlock`` only for tests marked ``triton``.""" + if not request.node.get_closest_marker("triton"): + yield + return + + from transformer_engine.jax import MeshResource, autocast + from transformer_engine.jax.flax import MoEBlock + + mod = sys.modules[__name__] + mod.MeshResource = MeshResource + mod.autocast = autocast + mod.MoEBlock = MoEBlock + yield + + +DTYPE = jnp.bfloat16 +BATCH_SIZE = 2 +SEQUENCE_LENGTH = 16 +HIDDEN_SIZE = 64 +INTERMEDIATE_SIZE = 128 +NUM_EXPERTS = 8 +NUM_EXPERTS_PER_TOK = 2 + + +def _make_inputs(key: jax.Array) -> jax.Array: + return jax.random.normal( + key, (BATCH_SIZE, SEQUENCE_LENGTH, HIDDEN_SIZE), dtype=DTYPE + ) + + +def _unwrap_partitioned(x): + return x.value if hasattr(x, "value") else x + + +@pytest.mark.triton +class TestDistributedMoEBlock: + @pytest.mark.parametrize("permutation_backend", ["pure_jax", "triton"]) + def test_ep2_fsdp2_matches_single_device(self, permutation_backend): + if not is_devices_enough(4): + pytest.skip("MoE distributed test requires 4 devices for EP=2 x FSDP=2.") + + key = jax.random.PRNGKey(11) + init_key, data_key = jax.random.split(key) + inputs = _make_inputs(data_key) + + base_kwargs = dict( + num_experts=NUM_EXPERTS, + num_experts_per_tok=NUM_EXPERTS_PER_TOK, + 
intermediate_size=INTERMEDIATE_SIZE, + permutation_backend=permutation_backend, + aux_loss_coeff=1e-2, + dtype=DTYPE, + ) + + single_block = MoEBlock(**base_kwargs) + + def loss_fn(block, variables, x): + output, aux_loss = block.apply(variables, x) + loss = jnp.mean(output.astype(jnp.float32) ** 2) + if aux_loss is not None: + loss = loss + aux_loss.astype(jnp.float32) + return loss, (output, aux_loss) + + with autocast(enabled=False, mesh_resource=MeshResource()): + single_variables = single_block.init(init_key, inputs) + (single_loss, (single_output, single_aux)), single_grads = jax.value_and_grad( + loss_fn, argnums=1, has_aux=True + )(single_block, single_variables, inputs) + + devices = np.asarray(jax.devices()[:4]).reshape(2, 2) + mesh = Mesh(devices, ("ep", "fsdp")) + # FSDP-style sharding: weights are sharded on a *non-contracting* + # weight axis (gathered before the GEMM); activations stay sharded on + # the *batch* axis throughout - the same fsdp mesh axis is reused for + # both. The TE primitives' custom_partitioning rules expect activations + # FSDP-sharded on batch, so we declare ("batch", "fsdp") AND pass + # ``input_axes=("batch", None, None)`` to enforce it on the inputs to + # the block. ("embed", "fsdp") shards the weight's hidden dim, which + # is gathered inside grouped_dense's custom_partitioning before GEMM + # (no reshard of activations needed because their layout is unchanged). 
+ logical_axis_rules = ( + ("exp", "ep"), + ("batch", "fsdp"), + ("embed", "fsdp"), + ) + sharded_block = MoEBlock( + expert_parallelism_axis="ep", + mesh=mesh, + input_axes=("batch", None, None), + **base_kwargs, + ) + + with mesh, autocast(enabled=False, mesh_resource=MeshResource(fsdp_resource="fsdp")): + with nn.logical_axis_rules(logical_axis_rules): + # ``MoEBlock`` registers params via ``with_logical_partitioning`` + # which only attaches LogicallyPartitioned metadata; the + # underlying jax.Array stays single-device unless ``init`` + # is run inside ``jax.jit`` with ``out_shardings``. Use the + # canonical Flax-Linen pattern (mirrors + # ``examples/jax/encoder/test_model_parallel_encoder.py``): + # 1. ``jax.eval_shape`` to trace abstract variables (keeps + # the LogicallyPartitioned wrappers; only the inner + # arrays become ShapeDtypeStruct); + # 2. ``nn.get_partition_spec`` to extract a tree of logical + # PartitionSpecs from those wrappers (treats + # LogicallyPartitioned as a leaf); + # 3. ``nn.logical_to_mesh_sharding`` to resolve those + # logical specs to NamedShardings via the active rules; + # 4. ``jax.jit(init, out_shardings=...)`` to actually + # place the params on-device with those shardings. 
+ abstract_variables = jax.eval_shape( + sharded_block.init, init_key, inputs + ) + logical_partition_spec = nn.get_partition_spec( + abstract_variables + ) + out_shardings = nn.logical_to_mesh_sharding( + logical_partition_spec, mesh, logical_axis_rules + ) + sharded_variables = jax.jit( + sharded_block.init, out_shardings=out_shardings + )(init_key, inputs) + (sharded_loss, (sharded_output, sharded_aux)), sharded_grads = ( + jax.value_and_grad(loss_fn, argnums=1, has_aux=True)( + sharded_block, sharded_variables, inputs + ) + ) + + wi_0 = _unwrap_partitioned(sharded_variables["params"]["wi_0"]) + wi_1 = _unwrap_partitioned(sharded_variables["params"]["wi_1"]) + wo = _unwrap_partitioned(sharded_variables["params"]["wo"]) + assert wi_0.sharding.spec == PartitionSpec("ep", "fsdp", None) + assert wi_1.sharding.spec == PartitionSpec("ep", "fsdp", None) + assert wo.sharding.spec == PartitionSpec("ep", None, "fsdp") + + assert_allclose(sharded_output, single_output, dtype=DTYPE, atol=5e-2, rtol=5e-2) + assert_allclose(sharded_loss, single_loss, dtype=jnp.float32, atol=5e-2, rtol=5e-2) + assert_allclose(sharded_aux, single_aux, dtype=jnp.float32, atol=5e-2, rtol=5e-2) + + for name in ("gate_kernel", "wi_0", "wi_1", "wo"): + grad_single = _unwrap_partitioned(single_grads["params"][name]) + grad_sharded = _unwrap_partitioned(sharded_grads["params"][name]) + assert_allclose( + grad_sharded, + grad_single, + dtype=DTYPE, + atol=1e-1, + rtol=1e-1, + err_msg=f"Distributed gradient mismatch for {name}", + ) diff --git a/tests/jax/test_moe_block.py b/tests/jax/test_moe_block.py new file mode 100644 index 0000000000..39a6bfd592 --- /dev/null +++ b/tests/jax/test_moe_block.py @@ -0,0 +1,313 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Basic tests for ``transformer_engine.jax.flax.MoEBlock``. 
+ +These tests exercise the MoEBlock on a single device (no expert parallelism) +and verify: + +* Forward pass runs end-to-end and produces the expected output shape. +* Backward pass yields finite, non-trivial parameter gradients. +* The two permutation backends (``"pure_jax"`` and ``"triton"``) produce + numerically equivalent outputs and gradients when given the same routing + decisions. +* Auxiliary load-balancing loss is returned when ``aux_loss_coeff > 0``. +* DeepSeek-style grouped top-k (``num_groups`` / ``group_topk``) runs. +* ``align_size > 0`` produces numerically-equivalent outputs to ``align_size = 0`` + for the pure-JAX backend (padding must not change the result). +""" + +import sys +from typing import Tuple + +import jax +import jax.numpy as jnp +import pytest + + +# The MoEBlock pulls in both the fused-router CUDA kernel and the Triton +# permutation kernels, so it can only run in the environment where those are +# available. We gate the test on the ``triton`` marker (the Triton permutation +# backend is stricter than the CUDA router). See ``conftest.py``. + + +@pytest.fixture(autouse=True, scope="function") +def _inject_moe(request): + """Lazy-load ``MoEBlock`` only for tests marked ``triton``.""" + if not request.node.get_closest_marker("triton"): + yield + return + + from transformer_engine.jax.flax import MoEBlock + + mod = sys.modules[__name__] + mod.MoEBlock = MoEBlock + yield + + +# ----------------------------------------------------------------------------- +# Configurations +# ----------------------------------------------------------------------------- +# +# Keep shapes small so the tests are cheap but still exercise every code path. 
+ +DTYPE = jnp.bfloat16 +BATCH_SIZE = 2 +SEQUENCE_LENGTH = 16 +HIDDEN_SIZE = 64 +INTERMEDIATE_SIZE = 128 +NUM_EXPERTS = 8 +NUM_EXPERTS_PER_TOK = 2 + + +def _make_inputs( + key: jax.Array, batch_size: int = BATCH_SIZE, sequence_length: int = SEQUENCE_LENGTH +) -> jax.Array: + return jax.random.normal( + key, (batch_size, sequence_length, HIDDEN_SIZE), dtype=DTYPE + ) + + +def _init_and_apply( + block, + inputs: jax.Array, + init_key: jax.Array, +) -> Tuple[dict, jax.Array, jax.Array]: + variables = block.init(init_key, inputs) + output, aux_loss = block.apply(variables, inputs) + return variables, output, aux_loss + + +def _unwrap_partitioned(x): + """Strip Flax logical-partition wrappers for numeric assertions.""" + return x.value if hasattr(x, "value") else x + + +# ----------------------------------------------------------------------------- +# Tests +# ----------------------------------------------------------------------------- + + +@pytest.mark.triton +class TestMoEBlockSingleDevice: + """Single-device smoke tests for :class:`MoEBlock`.""" + + @pytest.mark.parametrize("permutation_backend", ["pure_jax", "triton"]) + def test_forward_shape_and_finite(self, permutation_backend): + key = jax.random.PRNGKey(0) + init_key, data_key = jax.random.split(key) + + block = MoEBlock( + num_experts=NUM_EXPERTS, + num_experts_per_tok=NUM_EXPERTS_PER_TOK, + intermediate_size=INTERMEDIATE_SIZE, + permutation_backend=permutation_backend, + dtype=DTYPE, + ) + inputs = _make_inputs(data_key) + _variables, output, aux_loss = _init_and_apply(block, inputs, init_key) + + assert output.shape == inputs.shape, ( + f"Unexpected output shape {output.shape} for backend {permutation_backend}" + ) + assert output.dtype == inputs.dtype + assert jnp.all(jnp.isfinite(output)), "Output contains NaN/Inf" + assert aux_loss is None, "aux_loss should be None when aux_loss_coeff=0" + + @pytest.mark.parametrize("permutation_backend", ["pure_jax", "triton"]) + def test_backward_grad(self, 
permutation_backend): + key = jax.random.PRNGKey(1) + init_key, data_key = jax.random.split(key) + + block = MoEBlock( + num_experts=NUM_EXPERTS, + num_experts_per_tok=NUM_EXPERTS_PER_TOK, + intermediate_size=INTERMEDIATE_SIZE, + permutation_backend=permutation_backend, + dtype=DTYPE, + ) + inputs = _make_inputs(data_key) + variables = block.init(init_key, inputs) + + def loss_fn(variables, inputs): + output, _ = block.apply(variables, inputs) + return jnp.mean(output.astype(jnp.float32) ** 2) + + grads = jax.grad(loss_fn)(variables, inputs) + # All trainable kernels should receive a non-trivial gradient. + for name in ("gate_kernel", "wi_0", "wi_1", "wo"): + g = _unwrap_partitioned(grads["params"][name]) + assert jnp.all(jnp.isfinite(g)), f"{name} gradient has NaN/Inf" + assert jnp.any(g != 0.0), f"{name} gradient is identically zero" + + def test_pure_jax_triton_equivalence(self): + """Both permutation backends must produce the same forward + grads + under identical routing decisions. + + Since the two backends share the same routing path (TE's fused + top-k), fixing the gate kernel gives both the same routing decisions + and the remainder of the network is identical modulo the permutation + implementation, whose semantics are equivalent. + """ + key = jax.random.PRNGKey(2) + init_key, data_key = jax.random.split(key) + + base_kwargs = dict( + num_experts=NUM_EXPERTS, + num_experts_per_tok=NUM_EXPERTS_PER_TOK, + intermediate_size=INTERMEDIATE_SIZE, + dtype=DTYPE, + ) + pure_block = MoEBlock(permutation_backend="pure_jax", **base_kwargs) + triton_block = MoEBlock(permutation_backend="triton", **base_kwargs) + inputs = _make_inputs(data_key) + + # Share a single parameter tree so routing decisions and expert + # weights are identical for both backends. 
+ variables = pure_block.init(init_key, inputs) + + def loss_fn(block, variables, inputs): + output, _ = block.apply(variables, inputs) + return jnp.mean(output.astype(jnp.float32) ** 2), output + + (loss_pj, out_pj), grads_pj = jax.value_and_grad( + loss_fn, argnums=1, has_aux=True + )(pure_block, variables, inputs) + (loss_tr, out_tr), grads_tr = jax.value_and_grad( + loss_fn, argnums=1, has_aux=True + )(triton_block, variables, inputs) + + # BF16 tolerances: outputs come out of the grouped-GEMM + weighted + # sum so they accumulate error; we use ~2 ULPs worth of slack. + atol_out, rtol_out = 5e-2, 5e-2 + assert jnp.allclose(out_pj, out_tr, atol=atol_out, rtol=rtol_out), ( + f"Forward outputs differ across backends: max diff" + f" {jnp.max(jnp.abs(out_pj - out_tr))}" + ) + assert jnp.allclose(loss_pj, loss_tr, atol=atol_out, rtol=rtol_out) + + for name in ("gate_kernel", "wi_0", "wi_1", "wo"): + g_pj = _unwrap_partitioned(grads_pj["params"][name]) + g_tr = _unwrap_partitioned(grads_tr["params"][name]) + assert jnp.allclose(g_pj, g_tr, atol=1e-1, rtol=1e-1), ( + f"Gradient for {name} differs across backends: max diff" + f" {jnp.max(jnp.abs(g_pj - g_tr))}" + ) + + @pytest.mark.parametrize("permutation_backend", ["pure_jax", "triton"]) + def test_aux_loss_returned(self, permutation_backend): + key = jax.random.PRNGKey(3) + init_key, data_key = jax.random.split(key) + + block = MoEBlock( + num_experts=NUM_EXPERTS, + num_experts_per_tok=NUM_EXPERTS_PER_TOK, + intermediate_size=INTERMEDIATE_SIZE, + permutation_backend=permutation_backend, + aux_loss_coeff=1e-2, + dtype=DTYPE, + ) + inputs = _make_inputs(data_key) + _variables, output, aux_loss = _init_and_apply(block, inputs, init_key) + + assert output.shape == inputs.shape + assert aux_loss is not None, "aux_loss should be returned when coeff > 0" + assert aux_loss.shape == (), "aux_loss should be a scalar" + assert jnp.isfinite(aux_loss) + # With uniform-ish routing the loss should be small-positive, not huge. 
+ assert jnp.abs(aux_loss) < 1e2 + + @pytest.mark.parametrize("permutation_backend", ["pure_jax", "triton"]) + def test_group_topk_deepseek(self, permutation_backend): + """Exercise DeepSeek-style grouped top-k routing.""" + key = jax.random.PRNGKey(4) + init_key, data_key = jax.random.split(key) + + # num_groups must divide num_experts. + num_groups = 4 + group_topk = 2 + block = MoEBlock( + num_experts=NUM_EXPERTS, + num_experts_per_tok=NUM_EXPERTS_PER_TOK, + intermediate_size=INTERMEDIATE_SIZE, + permutation_backend=permutation_backend, + score_function="sigmoid", + num_groups=num_groups, + group_topk=group_topk, + dtype=DTYPE, + ) + inputs = _make_inputs(data_key) + _variables, output, _aux_loss = _init_and_apply(block, inputs, init_key) + + assert output.shape == inputs.shape + assert jnp.all(jnp.isfinite(output)) + + @pytest.mark.xfail( + reason=( + "TE grouped_dense FFI asserts sum(group_sizes) == M at " + "transformer_engine/jax/csrc/extensions/gemm.cpp:1029. With " + "align_size > 0 both backends produce a buffer where M >= " + "sum(group_sizes) (the slack is structural padding for JIT). " + "The kernel itself iterates over per-expert m_i from " + "group_sizes via nvte_multi_tensor_gemm and never reads past " + "sum(group_sizes), so relaxing that assertion to " + "`m >= sum_group_sizes` is the cleanest fix. The MoE block " + "deliberately does not fold the gap into a single expert " + "(that would create per-shard load imbalance under EP). " + "Re-enable once the FFI check is relaxed." + ), + strict=False, + ) + def test_align_size_equivalence_pure_jax(self): + """For the pure-JAX backend, ``align_size > 0`` must not change the + numerical output of the forward pass: padding tokens contribute zero + to every expert GEMM output (their input rows are zeros) and are + stripped before the weighted sum. 
+ """ + key = jax.random.PRNGKey(5) + init_key, data_key = jax.random.split(key) + + base_kwargs = dict( + num_experts=NUM_EXPERTS, + num_experts_per_tok=NUM_EXPERTS_PER_TOK, + intermediate_size=INTERMEDIATE_SIZE, + permutation_backend="pure_jax", + dtype=DTYPE, + ) + block_no_pad = MoEBlock(align_size=0, **base_kwargs) + block_pad = MoEBlock(align_size=16, **base_kwargs) + inputs = _make_inputs(data_key) + variables = block_no_pad.init(init_key, inputs) + + out_no_pad, _ = block_no_pad.apply(variables, inputs) + out_pad, _ = block_pad.apply(variables, inputs) + assert jnp.allclose(out_no_pad, out_pad, atol=5e-2, rtol=5e-2), ( + "align_size > 0 must not change pure_jax forward output; max diff" + f" {jnp.max(jnp.abs(out_no_pad - out_pad))}" + ) + + @pytest.mark.parametrize("permutation_backend", ["pure_jax", "triton"]) + def test_jit_and_determinism(self, permutation_backend): + """The block must be JIT-compilable and produce a deterministic + forward pass across repeat calls with the same params.""" + key = jax.random.PRNGKey(6) + init_key, data_key = jax.random.split(key) + + block = MoEBlock( + num_experts=NUM_EXPERTS, + num_experts_per_tok=NUM_EXPERTS_PER_TOK, + intermediate_size=INTERMEDIATE_SIZE, + permutation_backend=permutation_backend, + dtype=DTYPE, + ) + inputs = _make_inputs(data_key) + variables = block.init(init_key, inputs) + + @jax.jit + def forward(variables, inputs): + return block.apply(variables, inputs)[0] + + out_a = forward(variables, inputs) + out_b = forward(variables, inputs) + assert jnp.array_equal(out_a, out_b), "JITted forward is non-deterministic" diff --git a/transformer_engine/jax/flax/__init__.py b/transformer_engine/jax/flax/__init__.py index 92a968f061..0cd7835bcf 100644 --- a/transformer_engine/jax/flax/__init__.py +++ b/transformer_engine/jax/flax/__init__.py @@ -9,6 +9,7 @@ make_dot_general_cls, make_grouped_dense_cls, ) +from .moe import MoEBlock from .transformer import extend_logical_axis_rules from .transformer import 
DotProductAttention, MultiHeadAttention, RelativePositionBiases from .transformer import TransformerLayer, TransformerLayerType @@ -18,6 +19,7 @@ "LayerNorm", "LayerNormDenseGeneral", "LayerNormMLP", + "MoEBlock", "wrap_function_in_te_state_module", "make_dot_general_cls", "make_grouped_dense_cls", diff --git a/transformer_engine/jax/flax/moe.py b/transformer_engine/jax/flax/moe.py new file mode 100644 index 0000000000..050cbe84d0 --- /dev/null +++ b/transformer_engine/jax/flax/moe.py @@ -0,0 +1,943 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Flax Linen MoEBlock for TransformerEngine JAX. + +This module exposes :class:`MoEBlock`, a self-contained Flax Linen MoE layer +that wires together TE's fused router, a selectable token-dispatch backend +(pure-JAX ``unfused_*`` or fused Triton), TE's ``grouped_dense``, and an +optional ragged-all-to-all (A2A / A2Av) expert-parallelism strategy. + +Architecture +------------ + +The MoEBlock is decomposed into orthogonal stages so the EP wrapper can +inject collectives between them: + +* ``_route``: gate logits -> top-k routing decisions (+ aux loss). +* ``_global_permute``: scatter tokens to experts; produces + ``[num_tokens*topk + maybe_padding, hidden]`` and + per-expert ``group_sizes`` of length ``num_experts``. +* ``_expert_ffn``: three ``grouped_dense`` calls + activation. Operates + on whatever ``(rows, group_sizes, n_groups)`` it is + handed -- agnostic to whether ``n_groups`` is the + global expert count (no-EP) or the local expert + count (A2A-EP). +* ``_global_combine``: inverse of ``_global_permute`` -- gather + weighted + sum across top-k experts. + +Two top-level forward variants compose those stages: + +* ``_forward_no_ep``: route -> permute -> ffn -> combine. Each TE + primitive's ``custom_partitioning`` rule handles + DP / FSDP / TP automatically. 
+* ``_forward_a2a_ep``: wraps the body in :func:`jax.shard_map` and inserts + ``all_gather(group_sizes)`` + forward + ``ragged_all_to_all`` + local permute around the + FFN, plus their inverses afterwards. This is the + only place ``shard_map`` is used; A2A is the + canonical EP strategy because the in-flight NCCL + EP component will require this same data layout. + +Note on ``align_size > 0`` +-------------------------- + +Both permutation backends pad each expert's group to a multiple of +``align_size`` when requested, which is what CUBLASLt's grouped GEMM wants +for FP8 shape selection. The pure-JAX backend additionally appends a +zero-input padding tail to keep the buffer statically sized for JIT, so +``sum(group_sizes) <= sorted_inputs.shape[0]`` strictly. TE's +``grouped_dense`` FFI today asserts ``m == sum(group_sizes)`` at +``transformer_engine/jax/csrc/extensions/gemm.cpp:1029``; relaxing that +check to ``m >= sum(group_sizes)`` (the kernel itself only iterates over +``sum(group_sizes)`` rows via ``nvte_multi_tensor_gemm``) is the cleanest +way to support ``align_size > 0`` end-to-end. Until that lands the +``align_size > 0`` tests stay xfail. 
+""" + +from typing import Any, Callable, NewType, Optional, Tuple, Union + +import jax +import jax.numpy as jnp +from flax import linen as nn +from jax.sharding import PartitionSpec as P + +from ..dense import grouped_dense +from ..permutation import ( + _routing_map_to_selected_experts, + compute_ragged_all_to_all_params, + compute_reverse_ragged_all_to_all_params, + local_permute_after_a2a, + local_unpermute_before_a2a, + token_combine, + token_dispatch, + unfused_token_combine, + unfused_token_dispatch, +) +from ..quantize import noop_quantizer_set +from ..router import ScoreFunction, fused_moe_aux_loss, fused_topk_with_score_function +from ..sharding import with_sharding_constraint_by_logical_axes +from .module import TransformerEngineBase, _convert_to_activation_function + +PRNGKey = Any +Shape = Tuple[int, ...] +DType = NewType("DType", jnp.dtype) +Array = NewType("Array", jnp.ndarray) +Initializer = Callable[[PRNGKey, Shape, DType], Array] + + +__all__ = ["MoEBlock"] + + +# ============================================================================= +# MoEBlock +# ============================================================================= + + +class MoEBlock(TransformerEngineBase): + """Mixture-of-Experts Flax Linen block. + + Encapsulates the full MoE forward pass: gate projection, fused top-k + routing, optional auxiliary load-balancing loss, token dispatch, + per-expert two-layer FFN via grouped GEMMs, activation, token combine, + and optional ragged-all-to-all expert parallelism. + + Two permutation backends are pluggable via ``permutation_backend``: + + * ``"pure_jax"`` (default) -- argsort-based + :func:`~transformer_engine.jax.permutation.unfused_token_dispatch` / + :func:`~transformer_engine.jax.permutation.unfused_token_combine`. + Faster than Triton in profiling for DeepSeek-style configs. 
* ``"triton"`` -- TE's fused + :func:`~transformer_engine.jax.permutation.token_dispatch` / + :func:`~transformer_engine.jax.permutation.token_combine` Triton + kernels. + + Expert parallelism (``expert_parallelism_axis is not None``) uses the + **ragged-all-to-all** EP strategy (a.k.a. A2Av): each shard routes its + own tokens globally over all experts, then a forward + ``ragged_all_to_all`` exchanges per-expert chunks so each shard ends up + holding only the tokens for its local experts; after the FFN a reverse + ``ragged_all_to_all`` returns each shard's outputs to it. This matches + the layout the in-flight NCCL EP component expects. + + Parameters + ---------- + num_experts : int + Total number of experts. + num_experts_per_tok : int + Top-k value (number of experts each token is routed to). + intermediate_size : int + Per-expert FFN hidden dim. + + activation_type : str + FFN activation applied to the gate projection. Paired with the up + projection in the SwiGLU-style ``act(wi_0) * wi_1`` product. + Resolved via :mod:`flax.linen` (``"silu"``, ``"gelu"``, + ``"relu"``, ``"swish"``, ...) plus ``"linear"`` for identity. + + score_function : str or ScoreFunction + ``"softmax"`` (default) or ``"sigmoid"`` for + :func:`fused_topk_with_score_function`. + use_pre_softmax : bool + Apply softmax before top-k when ``score_function="softmax"``. + num_groups : int + Number of routing groups for grouped top-k (DeepSeek). ``<=0`` + disables. + group_topk : int + Top-k at the group level. ``<=0`` disables. + scaling_factor : float + Scaling factor applied to output probs. + use_expert_bias : bool + If ``True``, registers a learnable ``expert_bias`` parameter of + shape ``[num_experts]`` and passes it to the fused router. The + router primitive validates that this is paired with + ``score_function="sigmoid"``. + aux_loss_coeff : float + If ``> 0``, compute and return the MoE auxiliary load-balancing + loss scalar via :func:`fused_moe_aux_loss`. ``0`` disables. 
+ + gate_kernel_axes : tuple[str, ...] + Logical partitioning axes for the gate kernel of shape + ``[hidden, num_experts]``. + wi_kernel_axes : tuple[str, ...] + Logical partitioning axes for the ``wi_0`` and ``wi_1`` kernels of + shape ``[num_experts, hidden, intermediate]``. Default + ``("exp", "embed", "mlp")``. + wo_kernel_axes : tuple[str, ...] + Logical partitioning axes for the ``wo`` kernel of shape + ``[num_experts, intermediate, hidden]``. Default + ``("exp", "mlp", "embed")``. + input_axes : tuple[str, ...] + Logical axes used to constrain the input activation sharding at the + block boundary. ``()`` (default) means no constraint. + + expert_parallelism_axis : Optional[str] + Mesh axis along which experts are split. When set, the forward + pass is wrapped in :func:`jax.shard_map` that implements the + ragged-all-to-all EP strategy. When ``None`` (default), no + ``shard_map`` wrapper is used; each TE primitive's + ``custom_partitioning`` rule handles DP / FSDP / TP automatically. + tensor_parallelism_axis : Optional[str] + Mesh axis for tensor parallelism on the FFN intermediate dim. When + set, the output of the ``wo`` grouped GEMM is ``psum_scatter`` ed + along this axis. + + permutation_backend : str + ``"pure_jax"`` (default) or ``"triton"``. + align_size : int + Alignment for per-expert group sizes after padding. ``0`` disables + padding (the only supported configuration end-to-end today). ``>0`` + is required for quantized TE grouped GEMM whose recipe-specific + alignment must divide ``align_size``; see the module docstring for + the FFI assertion that currently blocks ``>0`` for both backends. + + dtype : jnp.dtype + Compute and parameter dtype. + kernel_init : Initializer + Initializer for all kernels (gate + per-expert FFN). Defaults to + ``variance_scaling(1.0, 'fan_in', 'truncated_normal')`` (Flax + convention). + use_bias : bool + If ``True``, registers per-expert FFN biases ``wi_0_bias``, + ``wi_1_bias``, ``wo_bias``. 
+ """ + + # Architecture + num_experts: int = 8 + num_experts_per_tok: int = 2 + intermediate_size: int = 2048 + activation_type: str = "silu" + + # Routing + score_function: Union[str, ScoreFunction] = "softmax" + use_pre_softmax: bool = False + num_groups: int = -1 + group_topk: int = -1 + scaling_factor: float = 1.0 + use_expert_bias: bool = False + aux_loss_coeff: float = 0.0 + + # Sharding + gate_kernel_axes: Tuple[Optional[str], ...] = () + wi_kernel_axes: Tuple[Optional[str], ...] = ("exp", "embed", "mlp") + wo_kernel_axes: Tuple[Optional[str], ...] = ("exp", "mlp", "embed") + input_axes: Tuple[Optional[str], ...] = () + + # Parallelism + expert_parallelism_axis: Optional[str] = None + tensor_parallelism_axis: Optional[str] = None + # ``jax.sharding.Mesh`` to use when ``expert_parallelism_axis`` is set. + # Required for the ``shard_map`` wrapper; ignored otherwise. + mesh: Optional[Any] = None + + # Permutation + permutation_backend: str = "pure_jax" + align_size: int = 0 + + # Dtypes / init / misc + dtype: DType = jnp.float32 + kernel_init: Optional[Initializer] = None + bias_init: Initializer = nn.initializers.zeros + expert_bias_init: Initializer = nn.initializers.zeros + use_bias: bool = False + + def __post_init__(self): + if self.kernel_init is None: + object.__setattr__( + self, + "kernel_init", + nn.initializers.variance_scaling( + 1.0, "fan_in", "truncated_normal", dtype=self.dtype + ), + ) + if self.permutation_backend not in ("pure_jax", "triton"): + raise ValueError( + "permutation_backend must be 'pure_jax' or 'triton'," + f" got {self.permutation_backend!r}" + ) + super().__post_init__() + + # ------------------------------------------------------------------ + # Parameter registration + # ------------------------------------------------------------------ + + def _make_params(self, hidden_size: int) -> dict: + """Register module parameters and return them as a dict.""" + gate_kernel = self.param( + "gate_kernel", + 
nn.with_logical_partitioning(self.kernel_init, self.gate_kernel_axes), + (hidden_size, self.num_experts), + self.dtype, + ) + wi_0 = self.param( + "wi_0", + nn.with_logical_partitioning(self.kernel_init, self.wi_kernel_axes), + (self.num_experts, hidden_size, self.intermediate_size), + self.dtype, + ) + wi_1 = self.param( + "wi_1", + nn.with_logical_partitioning(self.kernel_init, self.wi_kernel_axes), + (self.num_experts, hidden_size, self.intermediate_size), + self.dtype, + ) + wo = self.param( + "wo", + nn.with_logical_partitioning(self.kernel_init, self.wo_kernel_axes), + (self.num_experts, self.intermediate_size, hidden_size), + self.dtype, + ) + params: dict = { + "gate_kernel": gate_kernel, + "wi_0": wi_0, + "wi_1": wi_1, + "wo": wo, + } + if self.use_bias: + params["wi_0_bias"] = self.param( + "wi_0_bias", + nn.with_logical_partitioning(self.bias_init, ("exp", "mlp")), + (self.num_experts, self.intermediate_size), + self.dtype, + ) + params["wi_1_bias"] = self.param( + "wi_1_bias", + nn.with_logical_partitioning(self.bias_init, ("exp", "mlp")), + (self.num_experts, self.intermediate_size), + self.dtype, + ) + params["wo_bias"] = self.param( + "wo_bias", + nn.with_logical_partitioning(self.bias_init, ("exp", "embed")), + (self.num_experts, hidden_size), + self.dtype, + ) + if self.use_expert_bias: + params["expert_bias"] = self.param( + "expert_bias", + nn.with_logical_partitioning(self.expert_bias_init, ("exp",)), + (self.num_experts,), + self.dtype, + ) + return params + + # ------------------------------------------------------------------ + # Entry point + # ------------------------------------------------------------------ + + @nn.compact + def __call__(self, inputs: Array) -> Tuple[Array, Optional[Array]]: + """Run the MoE forward pass. + + Parameters + ---------- + inputs : jnp.ndarray + Input tensor of shape ``[batch, sequence, hidden]``. + + Returns + ------- + output : jnp.ndarray + Output tensor of shape ``[batch, sequence, hidden]``. 
+ aux_loss : Optional[jnp.ndarray] + Scalar auxiliary load-balancing loss when + ``aux_loss_coeff > 0``, else ``None``. + """ + assert inputs.ndim == 3, ( + f"MoEBlock expects [batch, sequence, hidden] input, got shape {inputs.shape}" + ) + inputs = with_sharding_constraint_by_logical_axes(inputs, self.input_axes) + + _, _, hidden_size = inputs.shape + params = self._make_params(hidden_size) + + # The gate runs OUTSIDE any EP shard_map: under EP each shard + # projects only its local slice of tokens, producing local gate + # logits with the same per-shard layout as ``inputs``. + gate_logits = self._gate(inputs, params["gate_kernel"]) + + if self.expert_parallelism_axis is None: + output, aux_loss = self._forward_no_ep(inputs, gate_logits, params) + else: + output, aux_loss = self._forward_a2a_ep(inputs, gate_logits, params) + + if self.aux_loss_coeff <= 0.0: + aux_loss = None + return output, aux_loss + + # ------------------------------------------------------------------ + # Gate + # ------------------------------------------------------------------ + + def _gate(self, inputs: jnp.ndarray, gate_kernel: jnp.ndarray) -> jnp.ndarray: + """Linear gate projection ``inputs @ gate_kernel``. + + Kept as a plain ``einsum`` (not ``DenseGeneral``) so it composes + cleanly with the EP shard_map: the gate runs in the outer + (pre-shard_map) scope and its output passes through the + ``shard_map`` boundary unchanged. + """ + kernel = gate_kernel.astype(inputs.dtype) + return jnp.einsum("bsh,he->bse", inputs, kernel) + + # ------------------------------------------------------------------ + # Route + # ------------------------------------------------------------------ + # + # The router is split into two pieces so the EP path can compute + # aux_loss over global (cross-shard) statistics without re-running + # the main top-k path. 
    # ``_route_topk`` returns the per-token routing
    # decisions (used by ``_global_permute``) and ``_compute_aux_loss``
    # returns the scalar load-balancing loss given the (possibly
    # gathered) logits.

    def _route_topk(
        self,
        logits_2d: jnp.ndarray,
        expert_bias: Optional[jnp.ndarray],
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Run the fused router top-k selection.

        Returns ``(sparse_probs, routing_map)``; both are
        ``[num_tokens, num_experts]``-shaped, with ``routing_map`` a
        mask marking the selected experts per token.
        """
        sparse_probs, routing_map = fused_topk_with_score_function(
            logits_2d,
            topk=self.num_experts_per_tok,
            use_pre_softmax=self.use_pre_softmax,
            num_groups=self.num_groups,
            group_topk=self.group_topk,
            scaling_factor=self.scaling_factor,
            score_function=self.score_function,
            expert_bias=expert_bias,
        )
        sparse_probs = sparse_probs.astype(self.dtype)
        return sparse_probs, routing_map

    def _compute_aux_loss(
        self,
        logits_2d: jnp.ndarray,
    ) -> Optional[jnp.ndarray]:
        """Compute the MoE auxiliary load-balancing loss.

        The score-for-aux kernel has no data dependency on the main
        routing kernel, so XLA can overlap them on the GPU.

        ``logits_2d`` should be the *full* logits tensor over the global
        token batch -- under EP the caller is responsible for
        :func:`jax.lax.all_gather` ing the logits before calling this so
        the aux_loss formula
        ``loss = (E * coeff / (k * T^2)) * sum_i(sum_t(probs[t,i]) * tokens[i])``
        sees the global ``T`` and the global ``tokens_per_expert``.
        """
        if self.aux_loss_coeff <= 0.0:
            return None
        aux_scores, aux_routing_map = fused_topk_with_score_function(
            logits_2d.astype(jnp.float32),
            topk=self.num_experts_per_tok,
            score_function=self.score_function,
            compute_aux_scores=True,
        )
        # Per-expert global token counts derived from the aux routing map.
        aux_tokens_per_expert = jnp.sum(
            aux_routing_map.astype(jnp.int32), axis=0
        )
        return fused_moe_aux_loss(
            aux_scores.astype(jnp.float32),
            aux_tokens_per_expert,
            topk=self.num_experts_per_tok,
            coeff=self.aux_loss_coeff,
        )

    # ------------------------------------------------------------------
    # Global permute (route -> token dispatch)
    # ------------------------------------------------------------------

    def _global_permute(
        self,
        inputs_2d: jnp.ndarray,
        sparse_probs: jnp.ndarray,
        routing_map: jnp.ndarray,
    ) -> dict:
        """Dispatch tokens to the global expert axis.

        Returns a permutation-result dict suitable both for the no-EP
        forward (where the same buffer feeds ``_expert_ffn`` directly) and
        for the A2A-EP path (where the buffer is sliced + sent over the EP
        axis before the FFN). The dict carries the per-backend opaque
        state needed to invert the dispatch in :meth:`_global_combine`.

        The output dict layout is::

            {
              "backend": "pure_jax" | "triton",
              "sorted_inputs": [buffer_size, hidden],
              "group_sizes": [num_experts],  # per-expert,
                                             # length == E always.
+ "perm_state": UnfusedPermState | None, # pure_jax + "row_id_map": jnp.ndarray | None, # triton + "pad_offsets": jnp.ndarray | None, # triton + "routing_weights": jnp.ndarray | None, # pure_jax + "merging_probs": jnp.ndarray | None, # triton + } + """ + num_tokens = inputs_2d.shape[0] + topk = self.num_experts_per_tok + + if self.permutation_backend == "pure_jax": + selected_experts, routing_weights = _routing_map_to_selected_experts( + sparse_probs, routing_map, topk + ) + sorted_inputs, perm_state, group_sizes = unfused_token_dispatch( + inputs_2d, + selected_experts, + num_experts=self.num_experts, + num_experts_per_tok=topk, + align_size=self.align_size, + ) + return { + "backend": "pure_jax", + "sorted_inputs": sorted_inputs, + "group_sizes": group_sizes, + "perm_state": perm_state, + "routing_weights": routing_weights, + } + + # triton + num_out_tokens = num_tokens * topk + align_size_arg = self.align_size if self.align_size > 0 else None + ( + sorted_inputs, + _permuted_probs, + row_id_map, + pad_offsets, + group_sizes, + ) = token_dispatch( + inputs_2d, + routing_map, + num_out_tokens=num_out_tokens, + probs=sparse_probs, + align_size=align_size_arg, + ) + return { + "backend": "triton", + "sorted_inputs": sorted_inputs, + "group_sizes": group_sizes, + "row_id_map": row_id_map, + "pad_offsets": pad_offsets, + "merging_probs": sparse_probs, + } + + # ------------------------------------------------------------------ + # Expert FFN (three grouped_dense calls + activation) + # ------------------------------------------------------------------ + + def _expert_ffn( + self, + sorted_inputs: jnp.ndarray, + group_sizes: jnp.ndarray, + params: dict, + n_groups: int, + ) -> jnp.ndarray: + """Run the per-expert SwiGLU-style FFN over a permuted buffer. + + Parameters + ---------- + sorted_inputs : jnp.ndarray + Permuted tokens of shape ``[buffer_size, hidden]`` (rows + grouped by expert). + group_sizes : jnp.ndarray + Per-group token counts of shape ``[n_groups]``. 
+ ``sum(group_sizes)`` must equal ``buffer_size`` (TE + ``grouped_dense`` FFI assertion at + ``transformer_engine/jax/csrc/extensions/gemm.cpp:1029``). + params : dict + Block parameters from :meth:`_make_params`. Reads ``wi_0``, + ``wi_1``, ``wo``, and the optional bias entries. + n_groups : int + Number of expert groups. Equals ``self.num_experts`` for the + no-EP path and ``num_experts // num_ep`` for the A2A-EP path. + Used to size the per-call quantizer set so the FP8 metadata + tensors match ``group_sizes``. + + Returns + ------- + expert_outputs : jnp.ndarray + ``[buffer_size, hidden]``. + """ + wi_0 = params["wi_0"] + wi_1 = params["wi_1"] + wo = params["wo"] + + # Each grouped_dense call gets its own quantizer_set with + # n_groups matching ``group_sizes``; this keeps the FP8 meta + # tensors correctly sized in both no-EP and A2A-EP cases. + q_set_w0 = self.generate_quantizer_set(postfix="_w0", n_groups=n_groups) + q_set_w1 = self.generate_quantizer_set(postfix="_w1", n_groups=n_groups) + q_set_wo = self.generate_quantizer_set(postfix="_wo", n_groups=n_groups) + + # Cast kernels to the activation dtype when no FP8 quantization + # is active (mirrors DenseGeneral). + if q_set_w0 == noop_quantizer_set: + wi_0 = wi_0.astype(sorted_inputs.dtype) + if q_set_w1 == noop_quantizer_set: + wi_1 = wi_1.astype(sorted_inputs.dtype) + if q_set_wo == noop_quantizer_set: + wo = wo.astype(sorted_inputs.dtype) + + # ``grouped_dense`` accepts per-expert bias of shape (G, N); it + # adds ``bias[i]`` to the ``group_sizes[i]`` rows belonging to + # expert ``i`` in the permuted layout. 
+ wi_0_bias = params.get("wi_0_bias") if self.use_bias else None + wi_1_bias = params.get("wi_1_bias") if self.use_bias else None + wo_bias = params.get("wo_bias") if self.use_bias else None + + layer_w0 = grouped_dense( + sorted_inputs, + wi_0, + group_sizes, + contracting_dims=((1,), (1,)), + bias=wi_0_bias, + quantizer_set=q_set_w0, + ) + layer_w1 = grouped_dense( + sorted_inputs, + wi_1, + group_sizes, + contracting_dims=((1,), (1,)), + bias=wi_1_bias, + quantizer_set=q_set_w1, + ) + + act_fn = _convert_to_activation_function(self.activation_type) + intermediate = act_fn(layer_w0) * layer_w1 + + expert_outputs = grouped_dense( + intermediate, + wo, + group_sizes, + contracting_dims=((1,), (1,)), + bias=wo_bias, + quantizer_set=q_set_wo, + ) + return expert_outputs + + # ------------------------------------------------------------------ + # Global combine (token combine -> back to [B, S, H]) + # ------------------------------------------------------------------ + + def _global_combine( + self, + expert_outputs: jnp.ndarray, + perm_result: dict, + batch_size: int, + sequence_length: int, + ) -> jnp.ndarray: + """Inverse of :meth:`_global_permute`. + + Gathers per-expert outputs back into ``[batch, sequence, hidden]`` + and applies the per-token weighted sum across the top-k experts. 
+ """ + backend = perm_result["backend"] + if backend == "pure_jax": + return unfused_token_combine( + expert_outputs, + perm_result["perm_state"], + perm_result["routing_weights"], + num_experts_per_tok=self.num_experts_per_tok, + batch_size=batch_size, + sequence_length=sequence_length, + ) + # triton + out_2d = token_combine( + expert_outputs, + perm_result["row_id_map"], + merging_probs=perm_result["merging_probs"], + pad_offsets=perm_result["pad_offsets"], + ) + hidden_size = out_2d.shape[-1] + return out_2d.reshape(batch_size, sequence_length, hidden_size).astype( + self.dtype + ) + + # ------------------------------------------------------------------ + # No-EP forward + # ------------------------------------------------------------------ + + def _forward_no_ep( + self, + inputs: jnp.ndarray, + gate_logits: jnp.ndarray, + params: dict, + ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: + """Single-shard or DP/FSDP/TP forward (no shard_map wrapper). + + DP / FSDP / TP all flow through each TE primitive's + ``custom_partitioning`` rule -- there is no cross-primitive + collective that the rules cannot express on their own, so a + ``shard_map`` is unnecessary here. 
+ """ + batch_size, sequence_length, hidden_size = inputs.shape + inputs_2d = inputs.reshape(-1, hidden_size) + logits_2d = gate_logits.reshape(-1, self.num_experts) + + sparse_probs, routing_map = self._route_topk( + logits_2d, params.get("expert_bias") + ) + aux_loss = self._compute_aux_loss(logits_2d) + perm = self._global_permute(inputs_2d, sparse_probs, routing_map) + expert_outputs = self._expert_ffn( + perm["sorted_inputs"], + perm["group_sizes"], + params, + n_groups=self.num_experts, + ) + output = self._global_combine( + expert_outputs, perm, batch_size, sequence_length + ) + + if self.tensor_parallelism_axis is not None: + output = jax.lax.psum_scatter( + output, + self.tensor_parallelism_axis, + scatter_dimension=2, + tiled=True, + ) + return output, aux_loss + + # ------------------------------------------------------------------ + # A2A (ragged-all-to-all) EP forward + # ------------------------------------------------------------------ + + def _forward_a2a_ep( + self, + inputs: jnp.ndarray, + gate_logits: jnp.ndarray, + params: dict, + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Wrap the body in a ``shard_map`` that runs a forward + ``ragged_all_to_all`` (A2A / A2Av) around the FFN. + + For each EP shard the wrapper: + + 1. Routes the shard's local tokens **globally** over all + ``num_experts`` experts (no roll, no local-mask -- every shard + sees the full expert axis). + 2. ``all_gather`` s its per-expert ``group_sizes`` so all shards + know the complete ``[num_ep, num_experts]`` token-count matrix. + 3. Forward ``ragged_all_to_all`` over the EP axis: each shard + sends per-expert chunks to the shard that owns those experts, + and receives chunks for its own ``num_experts // num_ep`` + local experts from every other shard. + 4. Reorders the received buffer from ``(source_shard, expert)`` + to ``(expert, source_shard)`` ordering so each local expert's + tokens are contiguous. + 5. 
Runs the three ``grouped_dense`` calls + activation over the + ``E_local``-group buffer. + 6. Reverses the local reorder. + 7. Reverse ``ragged_all_to_all`` over EP returns each shard's + token outputs to it. + 8. Inverts the global permute and applies the top-k weighted sum. + """ + from jax.experimental.shard_map import shard_map + + ep_axis = self.expert_parallelism_axis + if self.mesh is None: + raise ValueError( + "MoEBlock.expert_parallelism_axis is set; `mesh` must also" + " be provided so the EP shard_map can be built." + ) + mesh = self.mesh + num_ep = mesh.shape[ep_axis] + assert self.num_experts % num_ep == 0, ( + f"num_experts={self.num_experts} must be divisible by EP" + f" size={num_ep}" + ) + num_experts_local = self.num_experts // num_ep + + # Pre-compute the worst-case A2A receive buffer size (compile-time + # constant). Each shard contributes ``b_l*S*topk = B*S*topk/num_ep`` + # token-expert pairs across all experts; the worst case for one + # shard is "every global pair lands on this shard's local + # experts" -- ``num_ep * (B*S*topk/num_ep) = B*S*topk`` rows. JIT + # needs this static, so we use the global ``batch_size`` from the + # outer scope (sharded layouts don't change it). + global_batch_size, sequence_length, _hidden = inputs.shape + topk = self.num_experts_per_tok + recv_buffer_rows = global_batch_size * sequence_length * topk + + # Pack everything that crosses the shard_map boundary into a dict + # pytree. shard_map fully supports pytrees: ``in_specs`` must + # structurally match ``captured`` and we build them in lockstep + # so adding/removing an optional bias is one ``dict[name] = ...``. 
+ captured: dict = { + "inputs": inputs, + "gate_logits": gate_logits, + "wi_0": params["wi_0"], + "wi_1": params["wi_1"], + "wo": params["wo"], + } + in_specs: dict = { + "inputs": P(ep_axis, None, None), + "gate_logits": P(ep_axis, None, None), + "wi_0": P(ep_axis, None, None), + "wi_1": P(ep_axis, None, None), + "wo": P(ep_axis, None, None), + } + if "expert_bias" in params: + captured["expert_bias"] = params["expert_bias"] + in_specs["expert_bias"] = P(ep_axis) + if "wi_0_bias" in params: + for name in ("wi_0_bias", "wi_1_bias", "wo_bias"): + captured[name] = params[name] + in_specs[name] = P(ep_axis, None) + + def _a2a_fn(local: dict) -> Tuple[jnp.ndarray, jnp.ndarray]: + shard_id = jax.lax.axis_index(ep_axis) + + # -- Stage 1: per-shard route + global permute over all E -- + # Inside the shard_map body each input has its EP axis already + # consumed, so ``local_inputs.shape == [B/num_ep, S, H]``. + local_inputs = local["inputs"] + local_logits = local["gate_logits"] + local_b, local_s, local_h = local_inputs.shape + inputs_2d = local_inputs.reshape(-1, local_h) + logits_2d = local_logits.reshape(-1, self.num_experts) + + # The router operates over the full expert axis, so the + # EP-sharded ``expert_bias`` (in_spec ``P(ep_axis)``) must be + # all-gathered before being passed in. + if "expert_bias" in local: + full_expert_bias = jax.lax.all_gather( + local["expert_bias"], axis_name=ep_axis, tiled=True + ) + else: + full_expert_bias = None + sparse_probs, routing_map = self._route_topk( + logits_2d, full_expert_bias + ) + + # aux_loss must see the global token batch and the global + # tokens_per_expert: its formula ``E*coeff/(k*T^2) * sum_i( + # sum_t(probs[t,i]) * tokens[i])`` is not shard-decomposable + # (the sum_t * tokens product is data-dependent across + # shards). Cheapest fix: gather logits along the EP axis and + # run the aux-loss kernel on the global tensor. 
The aux + # branch has no data dependency on the main routing path so + # XLA can overlap the two on the GPU. + if self.aux_loss_coeff > 0.0: + global_logits_2d = jax.lax.all_gather( + logits_2d, axis_name=ep_axis, axis=0, tiled=True + ) + aux_loss = self._compute_aux_loss(global_logits_2d) + else: + aux_loss = None + + perm = self._global_permute(inputs_2d, sparse_probs, routing_map) + global_group_sizes = perm["group_sizes"] # [E] + + # -- Stage 2: gather per-expert counts across the EP axis -- + all_shards_tokens_per_expert = jax.lax.all_gather( + global_group_sizes[None, :], + axis_name=ep_axis, + axis=0, + tiled=True, + ) # [num_ep, num_experts] + + # -- Stage 3: forward ragged_all_to_all over EP -- + in_off, send_sz, out_off, recv_sz = compute_ragged_all_to_all_params( + all_shards_tokens_per_expert, shard_id, num_ep + ) + recv_buf = jnp.zeros( + (recv_buffer_rows, local_h), + dtype=perm["sorted_inputs"].dtype, + ) + x_recv = jax.lax.ragged_all_to_all( + perm["sorted_inputs"], + recv_buf, + in_off, + send_sz, + out_off, + recv_sz, + axis_name=ep_axis, + ) + + # -- Stage 4: local permute (source_shard, expert) -> (expert, shard) + sorted_x, local_group_sizes, local_perm_state = ( + local_permute_after_a2a( + x_recv, + all_shards_tokens_per_expert, + shard_id, + num_ep, + ) + ) + + # -- Stage 5: per-expert FFN (E_local groups) -- + local_params: dict = { + "wi_0": local["wi_0"], + "wi_1": local["wi_1"], + "wo": local["wo"], + } + if "wi_0_bias" in local: + local_params["wi_0_bias"] = local["wi_0_bias"] + local_params["wi_1_bias"] = local["wi_1_bias"] + local_params["wo_bias"] = local["wo_bias"] + expert_outputs = self._expert_ffn( + sorted_x, + local_group_sizes, + local_params, + n_groups=num_experts_local, + ) + + # -- Stage 6: invert local permute -- + x_send_back = local_unpermute_before_a2a( + expert_outputs, local_perm_state + ) + + # -- Stage 7: reverse ragged_all_to_all over EP -- + in_off_r, send_sz_r, out_off_r, recv_sz_r = ( + 
compute_reverse_ragged_all_to_all_params( + all_shards_tokens_per_expert, shard_id, num_ep + ) + ) + send_back_buf = jnp.zeros_like(perm["sorted_inputs"]) + y_back = jax.lax.ragged_all_to_all( + x_send_back, + send_back_buf, + in_off_r, + send_sz_r, + out_off_r, + recv_sz_r, + axis_name=ep_axis, + ) + + # -- Stage 8: invert global permute, weighted sum over top-k -- + output = self._global_combine( + y_back, perm, batch_size=local_b, sequence_length=local_s + ) + + if self.tensor_parallelism_axis is not None: + output = jax.lax.psum_scatter( + output, + self.tensor_parallelism_axis, + scatter_dimension=2, + tiled=True, + ) + + # ``out_specs`` must match the returned pytree structurally, + # so always emit a real scalar for aux_loss; the outer + # ``__call__`` re-strips it to None when aux_loss_coeff <= 0. + if aux_loss is None: + aux_loss = jnp.zeros((), dtype=self.dtype) + return output, aux_loss + + # ``check_rep=False`` disables shard_map's invariant that any + # output declared as ``P()`` is replicated across ``ep_axis``. + # We use ``axis_index(ep_axis)`` inside ``_a2a_fn`` so the body + # is genuinely non-replicated, which would otherwise (correctly) + # fail the check. ``ragged_all_to_all`` already produces the + # right cross-shard semantics; this is the standard JAX escape + # hatch when collectives + per-shard logic coexist. + return shard_map( + _a2a_fn, + mesh=mesh, + in_specs=(in_specs,), + out_specs=(P(ep_axis, None, None), P()), + check_rep=False, + )(captured) diff --git a/transformer_engine/jax/permutation.py b/transformer_engine/jax/permutation.py index 6a0a3229d9..f4599a7b8f 100644 --- a/transformer_engine/jax/permutation.py +++ b/transformer_engine/jax/permutation.py @@ -7,6 +7,17 @@ This module provides high-level token dispatch and combine operations for Mixture of Experts (MoE) models with proper automatic differentiation support. 
+Two backends are offered: + +* Fused, Triton-backed ``token_dispatch`` / ``token_combine`` - uses the + Triton kernels in ``transformer_engine.jax.triton_extensions.permutation``. +* Unfused, pure-JAX ``unfused_token_dispatch`` / ``unfused_token_combine`` - + uses only ``jnp.argsort`` + gather and is therefore compiled as plain XLA. + +Both backends support optional alignment padding (``align_size > 0``) so each +expert's group size is a multiple of ``align_size``, which is required for +quantized grouped GEMMs. + Token Dispatch (Permute): - Forward: Permute tokens according to routing map (scatter to experts) - Backward: Unpermute gradients (gather from experts) @@ -17,7 +28,7 @@ """ from functools import partial -from typing import Optional, Tuple +from typing import NamedTuple, Optional, Tuple import jax import jax.numpy as jnp @@ -38,6 +49,14 @@ "token_dispatch", "token_combine", "sort_chunks_by_index", + "unfused_token_dispatch", + "unfused_token_combine", + "UnfusedPermState", + # Ragged-all-to-all expert-parallelism helpers + "compute_ragged_all_to_all_params", + "compute_reverse_ragged_all_to_all_params", + "local_permute_after_a2a", + "local_unpermute_before_a2a", ] @@ -73,9 +92,7 @@ def token_dispatch( Routing mask of shape [batch, sequence, num_experts] or [num_tokens, num_experts]. Values: 1 = routed, 0 = not routed. num_out_tokens : int - The number of output tokens after permutation (before padding). For the dropless - case, this should be equal to the sum of routing_map. Must be provided explicitly - for JIT compatibility since output shape must be known at compile time. + Number of output tokens (rows in the permuted buffer, before padding). Must be > 0, e.g. int(jnp.sum(routing_map)) or num_tokens * top_k. Must be a compile-time constant for JIT. probs : Optional[jnp.ndarray] Optional routing probabilities of shape [batch, sequence, num_experts] or [num_tokens, num_experts]. If provided, permuted_probs will be returned. 
@@ -121,6 +138,8 @@ def token_dispatch( ((num_out_tokens + num_experts * (align_size - 1)) // align_size) * align_size This accounts for the maximum possible padding when each expert needs (align_size - 1) extra tokens to align, rounded down to align_size for buffer alignment. + + Non-positive num_out_tokens (e.g. -1) raises AssertionError. """ use_padding = align_size is not None num_experts = routing_map.shape[-1] @@ -134,6 +153,11 @@ def token_dispatch( else: worst_case_out_tokens = num_out_tokens + assert num_out_tokens > 0, ( + f"token_dispatch requires num_out_tokens > 0, got {num_out_tokens}. " + "Use int(jnp.sum(routing_map)) or num_tokens * top_k." + ) + return _token_dispatch( inp, routing_map, probs, num_out_tokens, worst_case_out_tokens, align_size, use_padding ) @@ -650,3 +674,654 @@ def _sort_chunks_by_index_bwd_rule( _sort_chunks_by_index.defvjp(_sort_chunks_by_index_fwd_rule, _sort_chunks_by_index_bwd_rule) + + +# ============================================================================= +# Unfused (pure-JAX) token dispatch / combine +# ============================================================================= +# +# The following implementations use only ``jnp.argsort`` + gather and compile +# to plain XLA. They are a drop-in alternative to ``token_dispatch`` / +# ``token_combine`` above, differing only in input/output conventions (the +# fused path takes ``routing_map`` and ``sparse_probs`` over all experts; the +# unfused path takes dense ``selected_experts`` and per-token ``weights`` of +# shape ``[..., topk]``). + + +# ----------------------------------------------------------------------------- +# Custom-VJP argsort-based gather. +# +# ``inputs[sort_indices]`` has a known inverse: ``output[argsort(sort_indices)]``. +# Using a custom VJP lets the backward pass exploit that inverse instead of +# relying on the compiler to discover it from the scatter-style default +# gradient of a gather, which is typically less efficient. 
+ + +@jax.custom_vjp +def _sort_activations(inputs: jax.Array, sort_indices: jax.Array) -> jax.Array: + """Sort ``inputs`` along the leading dim by ``sort_indices``.""" + assert inputs.shape[0] == sort_indices.shape[0], ( + f"inputs.shape[0]={inputs.shape[0]} must match" + f" sort_indices.shape[0]={sort_indices.shape[0]}" + ) + with jax.named_scope("unfused_sort_activations"): + return inputs[sort_indices, ...] + + +def _sort_activations_fwd( + inputs: jax.Array, sort_indices: jax.Array +) -> Tuple[jax.Array, jax.Array]: + return _sort_activations(inputs, sort_indices), sort_indices + + +def _sort_activations_bwd( + residuals: jax.Array, grads: jax.Array +) -> Tuple[jax.Array, None]: + sort_indices = residuals + # Inverse permutation: gather-by-argsort undoes the forward gather. + return _sort_activations(grads, jnp.argsort(sort_indices)), None + + +_sort_activations.defvjp(_sort_activations_fwd, _sort_activations_bwd) + + +def _routing_map_to_selected_experts( + sparse_probs: jnp.ndarray, + routing_map: jnp.ndarray, + topk: int, +) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Convert ``(sparse_probs, routing_map)`` from TE's fused router to the + ``(selected_experts, weights)`` format consumed by + :func:`unfused_token_dispatch`. + + ``routing_map`` is a boolean mask of shape ``[num_tokens, num_experts]`` + with exactly ``topk`` ``True`` positions per row. + """ + # Argsort on a bool tensor places ``True`` rows last (False=0 < True=1), + # so the last ``topk`` indices are the selected expert IDs. + selected_experts = jnp.argsort(routing_map, axis=-1)[..., -topk:] + weights = jnp.take_along_axis(sparse_probs, selected_experts, axis=-1) + return selected_experts, weights + + +# ----------------------------------------------------------------------------- +# Permutation state carried from dispatch to combine. + + +class UnfusedPermState(NamedTuple): + """Opaque state produced by :func:`unfused_token_dispatch`. 
+ + Attributes + ---------- + sorted_indices : jnp.ndarray + The argsort indices used in the forward sort. Needed to reverse the + permutation in :func:`unfused_token_combine`. Shape + ``[num_real_tokens + padding_size]``. + num_real_tokens : int + Number of real (non-padding) permuted tokens, i.e. + ``batch_size * sequence_length * num_experts_per_tok``. Compile-time + constant. + padding_size : int + Number of alignment-padding tokens appended to the sort buffer. Equals + ``num_experts * (align_size - 1)`` when ``align_size > 0``, else ``0``. + Compile-time constant. + """ + + sorted_indices: jax.Array + num_real_tokens: int + padding_size: int + + +# ----------------------------------------------------------------------------- +# Dispatch (permute) + + +def unfused_token_dispatch( + inputs: jnp.ndarray, + selected_experts: jnp.ndarray, + num_experts: int, + num_experts_per_tok: int, + align_size: int = 0, + roll_to_expert_id: Optional[int] = None, +) -> Tuple[jnp.ndarray, UnfusedPermState, jnp.ndarray]: + """Pure-JAX ``argsort``-based token dispatch. + + Parameters + ---------- + inputs : jnp.ndarray + Input tensor of shape ``[num_tokens, hidden_size]`` (or + ``[batch, seq, hidden]``; it will be flattened). + selected_experts : jnp.ndarray + Per-token expert IDs, shape ``[num_tokens, num_experts_per_tok]`` (or + ``[batch, seq, num_experts_per_tok]``). Integer dtype. + num_experts : int + Total number of experts. + num_experts_per_tok : int + Top-k. Must equal ``selected_experts.shape[-1]``. + align_size : int, default 0 + Alignment for each expert's group size. ``0`` disables padding; a value + ``> 0`` appends a static-size padding buffer so each resulting group + size is a multiple of ``align_size`` (required for quantized grouped + GEMM). + roll_to_expert_id : Optional[int] + If provided, rotates expert IDs by ``-roll_to_expert_id`` modulo + ``num_experts`` before the sort (ring-of-experts EP). The returned + ``group_sizes`` is rolled to match. 
+ + Returns + ------- + sorted_inputs : jnp.ndarray + Permuted tokens grouped by expert, shape + ``[num_real_tokens + padding_size, hidden_size]``. + perm_state : UnfusedPermState + State needed by :func:`unfused_token_combine`. + group_sizes : jnp.ndarray + Token count per expert, shape ``[num_experts]``. Each entry is a + multiple of ``align_size`` when ``align_size > 0``. + """ + assert num_experts_per_tok == selected_experts.shape[-1], ( + f"num_experts_per_tok={num_experts_per_tok} must match" + f" selected_experts.shape[-1]={selected_experts.shape[-1]}" + ) + assert align_size >= 0, f"align_size must be >= 0, got {align_size}" + + hidden_size = inputs.shape[-1] + inputs_2d = inputs.reshape(-1, hidden_size) + num_tokens = inputs_2d.shape[0] + num_real_tokens = num_tokens * num_experts_per_tok + + flatten_selected_experts = jnp.ravel(selected_experts) + + if align_size > 0: + # Per-expert token count, and how many extra tokens each expert needs + # to become aligned to ``align_size``. Using + # ``(align - count % align) % align`` gives 0 (not ``align``) when + # already aligned, so we never exceed the per-expert slot capacity of + # ``align_size - 1``. + token_count_per_expert = jnp.bincount( + flatten_selected_experts, length=num_experts + ) + padding_tokens_required_per_expert = ( + (align_size - (token_count_per_expert % align_size)) % align_size + ) + + # Build a static-size padding buffer of shape + # ``[num_experts * (align_size - 1)]``. Each expert ``i`` owns a slot + # of ``align_size - 1`` positions (worst-case padding, which occurs + # when ``token_count[i] % align_size == 1``). Within slot ``i``, + # positions ``[0, padding_needed)`` are assigned expert ``i`` and act + # as real padding; the rest are assigned to ``num_experts - 1`` as + # overflow placeholders that keep the buffer statically sized for JIT. 
+ max_padding_per_expert = align_size - 1 + max_total_padding_size = num_experts * max_padding_per_expert + positions = jnp.arange(max_total_padding_size) + expert_for_pos = positions // max_padding_per_expert + offset_in_slot = positions % max_padding_per_expert + padding_needed = padding_tokens_required_per_expert[expert_for_pos] + flatten_padding_selected_experts = jnp.where( + offset_in_slot < padding_needed, + expert_for_pos, + num_experts - 1, + ) + + flatten_selected_experts = jnp.concatenate( + [flatten_selected_experts, flatten_padding_selected_experts], axis=0 + ) + + if roll_to_expert_id is not None: + flatten_selected_experts = ( + flatten_selected_experts - roll_to_expert_id + ) % num_experts + + sorted_selected_experts = jnp.argsort(flatten_selected_experts) + + replicated_inputs_2d = jnp.repeat(inputs_2d, num_experts_per_tok, axis=0) + # Pad inputs with zeros so the sort operand shape matches the expanded + # selected-experts vector. + replicated_inputs_2d = jnp.pad( + replicated_inputs_2d, + pad_width=((0, max_total_padding_size), (0, 0)), + mode="constant", + constant_values=0.0, + ) + + sorted_inputs = _sort_activations(replicated_inputs_2d, sorted_selected_experts) + + # Compute ``group_sizes`` directly from counts rather than via + # ``bincount(flatten_selected_experts)``: the overflow placeholder + # tokens would inflate ``group_sizes[num_experts - 1]``, breaking the + # alignment guarantee. Direct computation gives each expert exactly + # ``ceil(count / align) * align`` tokens. 
+ group_sizes = token_count_per_expert + padding_tokens_required_per_expert + + if roll_to_expert_id is not None: + group_sizes = jnp.roll(group_sizes, -roll_to_expert_id) + + padding_size = max_total_padding_size + else: + if roll_to_expert_id is not None: + flatten_selected_experts = ( + flatten_selected_experts - roll_to_expert_id + ) % num_experts + + sorted_selected_experts = jnp.argsort(flatten_selected_experts) + + replicated_inputs_2d = jnp.repeat(inputs_2d, num_experts_per_tok, axis=0) + sorted_inputs = _sort_activations(replicated_inputs_2d, sorted_selected_experts) + + group_sizes = jnp.bincount(flatten_selected_experts, length=num_experts) + if roll_to_expert_id is not None: + group_sizes = jnp.roll(group_sizes, -roll_to_expert_id) + + padding_size = 0 + + perm_state = UnfusedPermState( + sorted_indices=sorted_selected_experts, + num_real_tokens=num_real_tokens, + padding_size=padding_size, + ) + return sorted_inputs, perm_state, group_sizes + + +# ----------------------------------------------------------------------------- +# Combine (unpermute + weighted sum) + + +def unfused_token_combine( + expert_outputs: jnp.ndarray, + perm_state: UnfusedPermState, + routing_weights: jnp.ndarray, + num_experts_per_tok: int, + batch_size: int, + sequence_length: int, +) -> jnp.ndarray: + """Pure-JAX ``argsort``-based token combine. + + Reverses the permutation performed by :func:`unfused_token_dispatch`, + strips any alignment-padding rows appended during dispatch, and applies a + per-token weighted sum across the top-k experts. + + Parameters + ---------- + expert_outputs : jnp.ndarray + Output of the expert FFN, shape + ``[num_real_tokens + padding_size, hidden_size]``. + perm_state : UnfusedPermState + State returned by :func:`unfused_token_dispatch`. + routing_weights : jnp.ndarray + Top-k routing weights, shape ``[batch*seq, num_experts_per_tok]`` + (or broadcastable to it after a ``reshape``). + num_experts_per_tok : int + Top-k. 
+ batch_size : int + Original batch size. + sequence_length : int + Original sequence length. + + Returns + ------- + output : jnp.ndarray + Combined output tensor of shape ``[batch_size, sequence_length, hidden_size]``. + """ + # Reverse the permutation: ``output[argsort(sorted_indices)]`` undoes + # ``input[sorted_indices]``. + unsort_intermediate = _sort_activations( + expert_outputs, + jnp.argsort(perm_state.sorted_indices), + ) + + # Strip alignment padding tokens appended during dispatch. After unsorting, + # the first ``num_real_tokens`` rows hold the real per-(token, top-k) + # outputs; any trailing rows are padding placeholders (zeros) and must be + # discarded before the reshape below. + if perm_state.padding_size > 0: + unsort_intermediate = unsort_intermediate[: perm_state.num_real_tokens] + + hidden_size = unsort_intermediate.shape[-1] + reshaped_weights = jnp.reshape(routing_weights, (-1, num_experts_per_tok)) + reshaped_intermediate = jnp.reshape( + unsort_intermediate, (reshaped_weights.shape[0], num_experts_per_tok, hidden_size) + ) + + # Cast weights to match intermediate dtype (weighted sum happens in + # intermediate dtype; callers can upcast before calling if higher + # precision weight-sum is desired). + reshaped_weights = reshaped_weights.astype(reshaped_intermediate.dtype) + with jax.named_scope("unfused_weight_sum"): + output = jnp.einsum( + "BKE,BK -> BE", + reshaped_intermediate, + reshaped_weights, + ) + return output.reshape(batch_size, sequence_length, hidden_size) + + +# ============================================================================= +# Ragged-all-to-all expert-parallelism helpers +# ============================================================================= +# +# These helpers support the ragged-all-to-all (A2A / A2Av) EP strategy used by +# :class:`transformer_engine.jax.flax.MoEBlock`. 
# The forward EP path looks
# like::
#
#   route -> global_permute -> AG(group_sizes, ep)
#         -> ragged_all_to_all(fwd, ep)
#         -> local_permute_after_a2a
#         -> grouped_dense x3 + activation
#         -> local_unpermute_before_a2a
#         -> ragged_all_to_all(reverse, ep)
#         -> global_combine
#
# The two ``compute_*_ragged_all_to_all_params`` functions translate
# ``all_shards_tokens_per_expert`` (an EP-axis ``all_gather`` of each shard's
# global ``group_sizes``) into the four ``ragged_all_to_all`` arguments
# (``input_offsets``, ``send_sizes``, ``output_offsets``, ``recv_sizes``).
# ``shard_id`` may be a traced value (e.g. from :func:`jax.lax.axis_index`),
# which is why every slice into ``all_shards_tokens_per_expert`` uses
# :func:`jax.lax.dynamic_slice`.
#
# These functions are pure JAX (no MaxText / TE dependencies) and equivalent
# to :func:`maxtext.layers.te_permutation.compute_ragged_all_to_all_params`
# / :func:`compute_reverse_ragged_all_to_all_params`.


def compute_ragged_all_to_all_params(
    all_shards_tokens_per_expert: jnp.ndarray,
    shard_id: jnp.ndarray,
    num_expert_shards: int,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Forward-direction ragged_all_to_all parameters.

    Computes the four index/size arrays consumed by
    :func:`jax.lax.ragged_all_to_all` for the **forward** EP shuffle, where
    each shard sends its expert-grouped tokens to the shard owning those
    experts.

    Parameters
    ----------
    all_shards_tokens_per_expert : jnp.ndarray
        Per-shard, per-expert token counts gathered across the EP axis.
        Shape ``[num_expert_shards, num_experts]``, integer dtype.
    shard_id : jnp.ndarray
        Index of the current shard along the EP axis (typically
        :func:`jax.lax.axis_index`). Must be a 0-d integer.
    num_expert_shards : int
        Static EP-axis size; must equal
        ``all_shards_tokens_per_expert.shape[0]``.

    Returns
    -------
    input_offsets : jnp.ndarray
        ``[num_expert_shards]``; start row in the local source buffer of
        each destination shard's chunk (exclusive prefix sum of
        ``send_sizes``).
    send_sizes : jnp.ndarray
        ``[num_expert_shards]``; tokens this shard sends to shard ``i``
        (sum of counts for the experts shard ``i`` owns).
    output_offsets : jnp.ndarray
        ``[num_expert_shards]``; row in shard ``i``'s receive buffer where
        this shard's chunk lands (sender-side semantics, per
        :func:`jax.lax.ragged_all_to_all`).
    recv_sizes : jnp.ndarray
        ``[num_expert_shards]``; tokens shard ``i`` sends to this shard.
    """
    num_experts = all_shards_tokens_per_expert.shape[1]
    assert num_experts % num_expert_shards == 0, (
        f"num_experts={num_experts} must be divisible by num_expert_shards"
        f"={num_expert_shards}"
    )
    experts_per_shard = num_experts // num_expert_shards

    # Our row of the gathered table, viewed as [dst_shard, local_expert].
    my_counts = jax.lax.dynamic_slice(
        all_shards_tokens_per_expert,
        (shard_id, 0),
        (1, num_experts),
    ).reshape(num_expert_shards, experts_per_shard)

    # send_sizes[i] = total tokens in our buffer destined for shard i's experts.
    send_sizes = my_counts.sum(axis=1)
    # Exclusive prefix sum: where each destination's chunk starts locally.
    input_offsets = jnp.cumsum(send_sizes) - send_sizes

    # recv_sizes[i] = tokens shard i routes into the experts we own, i.e. the
    # sum over our local-expert columns of shard i's row.
    first_owned_expert = shard_id * experts_per_shard
    owned_columns = jax.lax.dynamic_slice(
        all_shards_tokens_per_expert,
        (0, first_owned_expert),
        (num_expert_shards, experts_per_shard),
    )
    recv_sizes = owned_columns.sum(axis=1)

    # output_offsets uses sender-side semantics: output_offsets[j] is the row
    # in shard j's receive buffer where THIS shard's chunk goes, i.e. the
    # total already contributed to shard j by source shards 0..shard_id-1.
    traffic = all_shards_tokens_per_expert.reshape(
        num_expert_shards, num_expert_shards, experts_per_shard
    ).sum(axis=2)  # [src_shard, dst_shard]
    padded = jnp.concatenate(
        [jnp.zeros((1, num_expert_shards), dtype=traffic.dtype), traffic],
        axis=0,
    )
    # Row r of the exclusive cumsum = total sent to each dst by sources 0..r-1.
    exclusive = jnp.cumsum(padded, axis=0, dtype=traffic.dtype)
    output_offsets = jax.lax.dynamic_slice(
        exclusive,
        (shard_id, 0),
        (1, num_expert_shards),
    ).reshape(num_expert_shards)

    return input_offsets, send_sizes, output_offsets, recv_sizes


def compute_reverse_ragged_all_to_all_params(
    all_shards_tokens_per_expert: jnp.ndarray,
    shard_id: jnp.ndarray,
    num_expert_shards: int,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Reverse-direction ragged_all_to_all parameters.

    Mirror of :func:`compute_ragged_all_to_all_params` for the **reverse**
    EP shuffle that returns expert outputs to their source shards: sender
    and receiver roles are swapped, so what was received in the forward
    shuffle is now sent back, and vice versa.

    Parameters and return shapes are identical to
    :func:`compute_ragged_all_to_all_params`.
    """
    num_experts = all_shards_tokens_per_expert.shape[1]
    assert num_experts % num_expert_shards == 0, (
        f"num_experts={num_experts} must be divisible by num_expert_shards"
        f"={num_expert_shards}"
    )
    experts_per_shard = num_experts // num_expert_shards
    first_owned_expert = shard_id * experts_per_shard

    # Roles swap in reverse: send_sizes[i] = what shard i originally routed
    # into our experts (summed over our local-expert columns).
    owned_columns = jax.lax.dynamic_slice(
        all_shards_tokens_per_expert,
        (0, first_owned_expert),
        (num_expert_shards, experts_per_shard),
    )
    send_sizes = owned_columns.sum(axis=1)
    # Exclusive prefix sum of send_sizes.
    input_offsets = jnp.cumsum(send_sizes) - send_sizes

    # recv_sizes[i] = what we originally sent to shard i in the forward pass.
    my_counts = jax.lax.dynamic_slice(
        all_shards_tokens_per_expert,
        (shard_id, 0),
        (1, num_experts),
    ).reshape(num_expert_shards, experts_per_shard)
    recv_sizes = my_counts.sum(axis=1)

    # The reverse traffic matrix is the transpose of the forward one (what
    # shard i sends in reverse = what it received in forward). Exclusive
    # cumsum down the source-shard axis, then take our row.
    fwd_traffic = all_shards_tokens_per_expert.reshape(
        num_expert_shards, num_expert_shards, experts_per_shard
    ).sum(axis=2)  # forward: [src, dst]
    rev_traffic = fwd_traffic.T  # reverse: [src, dst]
    padded = jnp.concatenate(
        [jnp.zeros((1, num_expert_shards), dtype=rev_traffic.dtype), rev_traffic],
        axis=0,
    )
    exclusive = jnp.cumsum(padded, axis=0, dtype=rev_traffic.dtype)
    output_offsets = jax.lax.dynamic_slice(
        exclusive,
        (shard_id, 0),
        (1, num_expert_shards),
    ).reshape(num_expert_shards)

    return input_offsets, send_sizes, output_offsets, recv_sizes


# -----------------------------------------------------------------------------
# Local permute / unpermute
# -----------------------------------------------------------------------------
#
# After the forward ragged_all_to_all the receive buffer is laid out as
# ``[from_shard_0_chunk | from_shard_1_chunk | ... ]`` and within each chunk
# tokens are sorted by local-expert id. To feed ``grouped_dense`` we want
# ``[expert_0_block | expert_1_block | ... ]`` where each expert's block
# contains tokens from every source shard. ``local_permute_after_a2a``
# performs that reorder; ``local_unpermute_before_a2a`` undoes it before the
# reverse ragged_all_to_all.
#
# Implementation uses :func:`sort_chunks_by_index`, which is Triton-backed
# (see ``transformer_engine.jax.triton_extensions.permutation``) and has a
# paired custom-VJP backward. There is no pure-JAX alternative here -- the
# global :func:`unfused_token_dispatch` / :func:`token_dispatch` choice is
# unaffected by this; only the (small) post-A2A chunk reorder uses Triton
# unconditionally.
+ + +def local_permute_after_a2a( + x_recv: jnp.ndarray, + all_shards_tokens_per_expert: jnp.ndarray, + shard_id: jnp.ndarray, + num_expert_shards: int, +) -> Tuple[jnp.ndarray, jnp.ndarray, dict]: + """Reorder tokens received via ragged_all_to_all so each local expert's + tokens are contiguous. + + This is the EP-side complement to the global :func:`token_dispatch` / + :func:`unfused_token_dispatch`. Internally uses + :func:`sort_chunks_by_index` (Triton-backed) for both the forward sort + and -- via :func:`local_unpermute_before_a2a` -- the inverse. + + Parameters + ---------- + x_recv : jnp.ndarray + Output of the forward ``ragged_all_to_all`` of shape + ``[buffer_size, hidden_size]``. Layout: source-shard major, then + local-expert id within each source chunk. + all_shards_tokens_per_expert : jnp.ndarray + Per-shard, per-expert token counts of shape + ``[num_expert_shards, num_experts]``. + shard_id : jnp.ndarray + Current EP shard index (typically a traced + :func:`jax.lax.axis_index`). + num_expert_shards : int + Static EP-axis size. + + Returns + ------- + sorted_x : jnp.ndarray + Tokens reordered into expert-major layout. Same shape as ``x_recv``. + local_group_sizes : jnp.ndarray + Per-local-expert token counts of shape ``[local_expert_size]``. + state : dict + Opaque state for :func:`local_unpermute_before_a2a`. + """ + num_experts = all_shards_tokens_per_expert.shape[1] + assert num_experts % num_expert_shards == 0, ( + f"num_experts={num_experts} must be divisible by num_expert_shards" + f"={num_expert_shards}" + ) + local_expert_size = num_experts // num_expert_shards + local_expert_start = shard_id * local_expert_size + local_expert_columns = jax.lax.dynamic_slice( + all_shards_tokens_per_expert, + start_indices=(0, local_expert_start), + slice_sizes=(num_expert_shards, local_expert_size), + ) + + # Flat sizes in source-major order, matching the receive buffer layout: + # [(s0,e0), (s0,e1), ..., (s1,e0), (s1,e1), ...] 
+ split_sizes = local_expert_columns.reshape(-1) + + # Permutation that maps source-major -> expert-major: + # original index = s * E_local + e + # target index = e * num_shards + s + indices_matrix = jnp.arange( + num_expert_shards * local_expert_size, dtype=jnp.int32 + ).reshape(num_expert_shards, local_expert_size) + sorted_chunk_indices = indices_matrix.T.reshape(-1) + + sorted_x, _ = sort_chunks_by_index(x_recv, split_sizes, sorted_chunk_indices) + sorted_split_sizes = split_sizes[sorted_chunk_indices] + inverse_chunk_indices = jnp.argsort(sorted_chunk_indices) + local_group_sizes = jnp.sum(local_expert_columns, axis=0) + state = { + "sorted_split_sizes": sorted_split_sizes, + "inverse_chunk_indices": inverse_chunk_indices, + } + return sorted_x, local_group_sizes, state + + +def local_unpermute_before_a2a( + expert_outputs: jnp.ndarray, + state: dict, +) -> jnp.ndarray: + """Inverse of :func:`local_permute_after_a2a`. + + Parameters + ---------- + expert_outputs : jnp.ndarray + Output of the local expert FFN of shape ``[buffer_size, hidden_size]``, + in expert-major layout. + state : dict + Opaque state returned by :func:`local_permute_after_a2a`. + + Returns + ------- + unsorted_x : jnp.ndarray + Tokens reordered back into source-shard-major layout, ready for the + reverse ``ragged_all_to_all``. Same shape as ``expert_outputs``. + """ + out, _ = sort_chunks_by_index( + expert_outputs, + state["sorted_split_sizes"], + state["inverse_chunk_indices"], + ) + return out diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py index bc9a2660b7..5a7c4eb33d 100644 --- a/transformer_engine/pytorch/permutation.py +++ b/transformer_engine/pytorch/permutation.py @@ -53,6 +53,9 @@ def moe_permute_index_map_forward( f"Permute not possible: inp.size(0) ({inp.size(0)}) must match " f"index.size(0) ({index.size(0)})." ) + assert num_out_tokens >= 0, ( + f"moe_permute (index map) requires num_out_tokens >= 0, got {num_out_tokens}." 
+ ) if index.dtype != torch.int32: warnings.warn( f"The data type of the input `index` of Permute is {index.dtype}! " @@ -91,6 +94,10 @@ def _moe_permute_index_map_fake( # pylint: disable=unused-argument """Fake implementation for shape inference.""" num_tokens = inp.shape[0] topK = index.shape[1] + if num_tokens > 0: + assert num_out_tokens >= 0, ( + f"moe_permute (index map) requires num_out_tokens >= 0, got {num_out_tokens}." + ) # Infer output shape output_tokens = num_out_tokens if num_out_tokens > 0 else num_tokens * topK @@ -304,6 +311,10 @@ def moe_permute_mask_map_forward( f"Permute not possible: inp.size(0) ({inp.size(0)}) must match " f"routing_map.size(0) ({routing_map.size(0)})." ) + assert num_out_tokens > 0, ( + f"moe_permute (mask map) requires num_out_tokens > 0, got {num_out_tokens}. " + "Use int(routing_map.sum()) or num_tokens * top_k." + ) num_tokens, hidden_size = inp.size() num_experts = routing_map.size(1) @@ -424,13 +435,26 @@ def _moe_permute_mask_map_forward_fake( # pylint: disable=unused-argument num_tokens = inp.shape[0] hidden_size = inp.shape[1] num_experts = routing_map.shape[1] + if num_tokens > 0: + assert num_out_tokens > 0, ( + f"moe_permute (mask map) requires num_out_tokens > 0, got {num_out_tokens}. " + "Use int(routing_map.sum()) or num_tokens * top_k." + ) + out_rows = num_out_tokens + else: + # Match `moe_permute_mask_map_forward` empty-input fast path (ignores num_out_tokens). 
+ out_rows = 0 # row_id_map: (num_tokens, num_experts * 2 + 1) - fake_output = torch.empty((num_out_tokens, hidden_size), dtype=inp.dtype, device=inp.device) + fake_output = torch.empty((out_rows, hidden_size), dtype=inp.dtype, device=inp.device) fake_row_id_map = torch.empty( (num_tokens, num_experts * 2 + 1), dtype=torch.int32, device=inp.device ) if probs is not None: - fake_permuted_probs = torch.empty((num_out_tokens,), dtype=probs.dtype, device=inp.device) + fake_permuted_probs = ( + torch.empty((out_rows,), dtype=probs.dtype, device=inp.device) + if out_rows > 0 + else torch.empty(0, device=inp.device) + ) else: fake_permuted_probs = torch.empty(0, device=inp.device) return fake_output, fake_row_id_map, fake_permuted_probs @@ -852,7 +876,7 @@ def _moe_unpermute_mask_map_backward_wrapper(ctx, unpermuted_act_grad): def moe_permute( inp: torch.Tensor, routing_map: torch.Tensor, - num_out_tokens: int = -1, + num_out_tokens: int, max_token_num: int = -1, map_type: str = "mask", ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -871,13 +895,13 @@ def moe_permute( The values in it: 1 means the token is routed to this expert and 0 means not. If map_type is 'index', routing_map is of shape [num_tokens, topK] and dtype 'int32'. The values in it are the routed expert indices. - num_out_tokens : int, default = -1 - The effective output token count, representing the number of tokens not dropped. - By default, set to '-1', meaning no tokens are dropped. + num_out_tokens : int + Number of output tokens (rows in the permuted buffer). + mask map: must be > 0, e.g. int(routing_map.sum()) or num_tokens * top_k. + index map: must be >= 0; 0 means infer as num_tokens * top_k. max_token_num : int, default = -1 - The maximum number of tokens, used for workspace allocation. - By default, set to '-1', meaning the calculation of the size of workspace is - automatically taken over by the operator. + Workspace sizing hint, only used for map_type='index'. Ignored for 'mask'. 
+ map_type : str, default = 'mask' Type of the routing map tensor. Options are: 'mask', 'index'. @@ -902,7 +926,7 @@ def moe_permute_with_probs( inp: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor, - num_out_tokens: int = -1, + num_out_tokens: int, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Permute the tokens and probs based on the routing_map. @@ -921,9 +945,9 @@ def moe_permute_with_probs( routing_map : torch.Tensor The token to expert mapping tensor of shape [num_tokens, num_experts] and dtype 'int32'. The values in it: 1 means the token is routed to this expert and 0 means not. - num_out_tokens : int, default = -1 - The effective output token count, representing the number of tokens not dropped. - By default, set to '-1', meaning no tokens are dropped. + num_out_tokens : int + Number of output tokens (rows in the permuted buffer). Must be > 0, + e.g. int(routing_map.sum()) or num_tokens * top_k. """ if isinstance(inp, QuantizedTensor) and torch.compiler.is_compiling(): raise RuntimeError( diff --git a/transformer_engine/pytorch/triton/permutation.py b/transformer_engine/pytorch/triton/permutation.py index 4902bc686c..c155d73e1e 100644 --- a/transformer_engine/pytorch/triton/permutation.py +++ b/transformer_engine/pytorch/triton/permutation.py @@ -151,7 +151,7 @@ def permute_with_mask_map( num_experts : int Number of experts in the input tensor. num_out_tokens : int - Number of tokens in the permuted tensor. + Number of rows allocated for the permuted tensor (must be a positive integer). hidden_size : int Hidden size of the input tensor. scale_hidden_dim : int