From af29c7c401b605cb1f89ef9efd9203b526843507 Mon Sep 17 00:00:00 2001 From: tdophung Date: Mon, 13 Apr 2026 17:11:24 -0700 Subject: [PATCH 1/6] Change docs, and guard against invalid num_out_tokens in mask_map code path Signed-off-by: tdophung --- transformer_engine/jax/permutation.py | 11 ++-- transformer_engine/pytorch/permutation.py | 50 ++++++++++++++----- .../pytorch/triton/permutation.py | 2 +- 3 files changed, 46 insertions(+), 17 deletions(-) diff --git a/transformer_engine/jax/permutation.py b/transformer_engine/jax/permutation.py index 6a0a3229d9..81972aac0f 100644 --- a/transformer_engine/jax/permutation.py +++ b/transformer_engine/jax/permutation.py @@ -73,9 +73,7 @@ def token_dispatch( Routing mask of shape [batch, sequence, num_experts] or [num_tokens, num_experts]. Values: 1 = routed, 0 = not routed. num_out_tokens : int - The number of output tokens after permutation (before padding). For the dropless - case, this should be equal to the sum of routing_map. Must be provided explicitly - for JIT compatibility since output shape must be known at compile time. + Number of output tokens (rows in the permuted buffer, before padding). Must be > 0, e.g. int(jnp.sum(routing_map)) or num_tokens * top_k. Must be a compile-time constant for JIT. probs : Optional[jnp.ndarray] Optional routing probabilities of shape [batch, sequence, num_experts] or [num_tokens, num_experts]. If provided, permuted_probs will be returned. @@ -121,6 +119,8 @@ def token_dispatch( ((num_out_tokens + num_experts * (align_size - 1)) // align_size) * align_size This accounts for the maximum possible padding when each expert needs (align_size - 1) extra tokens to align, rounded down to align_size for buffer alignment. + + Non-positive num_out_tokens (e.g. -1) raises AssertionError. 
""" use_padding = align_size is not None num_experts = routing_map.shape[-1] @@ -134,6 +134,11 @@ def token_dispatch( else: worst_case_out_tokens = num_out_tokens + assert num_out_tokens > 0, ( + f"token_dispatch requires num_out_tokens > 0, got {num_out_tokens}. " + "Use int(jnp.sum(routing_map)) or num_tokens * top_k." + ) + return _token_dispatch( inp, routing_map, probs, num_out_tokens, worst_case_out_tokens, align_size, use_padding ) diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py index bc9a2660b7..5a7c4eb33d 100644 --- a/transformer_engine/pytorch/permutation.py +++ b/transformer_engine/pytorch/permutation.py @@ -53,6 +53,9 @@ def moe_permute_index_map_forward( f"Permute not possible: inp.size(0) ({inp.size(0)}) must match " f"index.size(0) ({index.size(0)})." ) + assert num_out_tokens >= 0, ( + f"moe_permute (index map) requires num_out_tokens >= 0, got {num_out_tokens}." + ) if index.dtype != torch.int32: warnings.warn( f"The data type of the input `index` of Permute is {index.dtype}! " @@ -91,6 +94,10 @@ def _moe_permute_index_map_fake( # pylint: disable=unused-argument """Fake implementation for shape inference.""" num_tokens = inp.shape[0] topK = index.shape[1] + if num_tokens > 0: + assert num_out_tokens >= 0, ( + f"moe_permute (index map) requires num_out_tokens >= 0, got {num_out_tokens}." + ) # Infer output shape output_tokens = num_out_tokens if num_out_tokens > 0 else num_tokens * topK @@ -304,6 +311,10 @@ def moe_permute_mask_map_forward( f"Permute not possible: inp.size(0) ({inp.size(0)}) must match " f"routing_map.size(0) ({routing_map.size(0)})." ) + assert num_out_tokens > 0, ( + f"moe_permute (mask map) requires num_out_tokens > 0, got {num_out_tokens}. " + "Use int(routing_map.sum()) or num_tokens * top_k." 
+ ) num_tokens, hidden_size = inp.size() num_experts = routing_map.size(1) @@ -424,13 +435,26 @@ def _moe_permute_mask_map_forward_fake( # pylint: disable=unused-argument num_tokens = inp.shape[0] hidden_size = inp.shape[1] num_experts = routing_map.shape[1] + if num_tokens > 0: + assert num_out_tokens > 0, ( + f"moe_permute (mask map) requires num_out_tokens > 0, got {num_out_tokens}. " + "Use int(routing_map.sum()) or num_tokens * top_k." + ) + out_rows = num_out_tokens + else: + # Match `moe_permute_mask_map_forward` empty-input fast path (ignores num_out_tokens). + out_rows = 0 # row_id_map: (num_tokens, num_experts * 2 + 1) - fake_output = torch.empty((num_out_tokens, hidden_size), dtype=inp.dtype, device=inp.device) + fake_output = torch.empty((out_rows, hidden_size), dtype=inp.dtype, device=inp.device) fake_row_id_map = torch.empty( (num_tokens, num_experts * 2 + 1), dtype=torch.int32, device=inp.device ) if probs is not None: - fake_permuted_probs = torch.empty((num_out_tokens,), dtype=probs.dtype, device=inp.device) + fake_permuted_probs = ( + torch.empty((out_rows,), dtype=probs.dtype, device=inp.device) + if out_rows > 0 + else torch.empty(0, device=inp.device) + ) else: fake_permuted_probs = torch.empty(0, device=inp.device) return fake_output, fake_row_id_map, fake_permuted_probs @@ -852,7 +876,7 @@ def _moe_unpermute_mask_map_backward_wrapper(ctx, unpermuted_act_grad): def moe_permute( inp: torch.Tensor, routing_map: torch.Tensor, - num_out_tokens: int = -1, + num_out_tokens: int, max_token_num: int = -1, map_type: str = "mask", ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -871,13 +895,13 @@ def moe_permute( The values in it: 1 means the token is routed to this expert and 0 means not. If map_type is 'index', routing_map is of shape [num_tokens, topK] and dtype 'int32'. The values in it are the routed expert indices. - num_out_tokens : int, default = -1 - The effective output token count, representing the number of tokens not dropped. 
- By default, set to '-1', meaning no tokens are dropped. + num_out_tokens : int + Number of output tokens (rows in the permuted buffer). + mask map: must be > 0, e.g. int(routing_map.sum()) or num_tokens * top_k. + index map: must be >= 0; 0 means infer as num_tokens * top_k. max_token_num : int, default = -1 - The maximum number of tokens, used for workspace allocation. - By default, set to '-1', meaning the calculation of the size of workspace is - automatically taken over by the operator. + Workspace sizing hint, only used for map_type='index'. Ignored for 'mask'. + map_type : str, default = 'mask' Type of the routing map tensor. Options are: 'mask', 'index'. @@ -902,7 +926,7 @@ def moe_permute_with_probs( inp: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor, - num_out_tokens: int = -1, + num_out_tokens: int, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Permute the tokens and probs based on the routing_map. @@ -921,9 +945,9 @@ def moe_permute_with_probs( routing_map : torch.Tensor The token to expert mapping tensor of shape [num_tokens, num_experts] and dtype 'int32'. The values in it: 1 means the token is routed to this expert and 0 means not. - num_out_tokens : int, default = -1 - The effective output token count, representing the number of tokens not dropped. - By default, set to '-1', meaning no tokens are dropped. + num_out_tokens : int + Number of output tokens (rows in the permuted buffer). Must be > 0, + e.g. int(routing_map.sum()) or num_tokens * top_k. """ if isinstance(inp, QuantizedTensor) and torch.compiler.is_compiling(): raise RuntimeError( diff --git a/transformer_engine/pytorch/triton/permutation.py b/transformer_engine/pytorch/triton/permutation.py index 4902bc686c..c155d73e1e 100644 --- a/transformer_engine/pytorch/triton/permutation.py +++ b/transformer_engine/pytorch/triton/permutation.py @@ -151,7 +151,7 @@ def permute_with_mask_map( num_experts : int Number of experts in the input tensor. 
num_out_tokens : int - Number of tokens in the permuted tensor. + Number of rows allocated for the permuted tensor (must be a positive integer). hidden_size : int Hidden size of the input tensor. scale_hidden_dim : int From 7215415644d2f39f9587a56b8b75e21ce2946398 Mon Sep 17 00:00:00 2001 From: tdophung Date: Tue, 21 Apr 2026 11:57:08 -0700 Subject: [PATCH 2/6] initial impl Signed-off-by: tdophung --- tests/jax/test_moe_block.py | 292 ++++++++ transformer_engine/jax/flax/__init__.py | 2 + transformer_engine/jax/flax/moe.py | 890 +++++++++++++++++++++++ transformer_engine/jax/mt_permutation.py | 356 +++++++++ 4 files changed, 1540 insertions(+) create mode 100644 tests/jax/test_moe_block.py create mode 100644 transformer_engine/jax/flax/moe.py create mode 100644 transformer_engine/jax/mt_permutation.py diff --git a/tests/jax/test_moe_block.py b/tests/jax/test_moe_block.py new file mode 100644 index 0000000000..458d674c7d --- /dev/null +++ b/tests/jax/test_moe_block.py @@ -0,0 +1,292 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Basic tests for ``transformer_engine.jax.flax.MoEBlock``. + +These tests exercise the MoEBlock on a single device (no expert parallelism) +and verify: + +* Forward pass runs end-to-end and produces the expected output shape. +* Backward pass yields finite, non-trivial parameter gradients. +* The two permutation backends (``"pure_jax"`` and ``"triton"``) produce + numerically equivalent outputs and gradients when given the same routing + decisions. +* Auxiliary load-balancing loss is returned when ``aux_loss_coeff > 0``. +* DeepSeek-style grouped top-k (``num_groups`` / ``group_topk``) runs. +* ``align_size > 0`` produces numerically-equivalent outputs to ``align_size = 0`` + for the pure-JAX backend (padding must not change the result). 
+""" + +import sys +from typing import Tuple + +import jax +import jax.numpy as jnp +import pytest + + +# The MoEBlock pulls in both the fused-router CUDA kernel and the Triton +# permutation kernels, so it can only run in the environment where those are +# available. We gate the test on the ``triton`` marker (the Triton permutation +# backend is stricter than the CUDA router). See ``conftest.py``. + + +@pytest.fixture(autouse=True, scope="function") +def _inject_moe(request): + """Lazy-load ``MoEBlock`` only for tests marked ``triton``.""" + if not request.node.get_closest_marker("triton"): + yield + return + + from transformer_engine.jax.flax import MoEBlock + + mod = sys.modules[__name__] + mod.MoEBlock = MoEBlock + yield + + +# ----------------------------------------------------------------------------- +# Configurations +# ----------------------------------------------------------------------------- +# +# Keep shapes small so the tests are cheap but still exercise every code path. 
+ +DTYPE = jnp.bfloat16 +BATCH_SIZE = 2 +SEQUENCE_LENGTH = 16 +HIDDEN_SIZE = 64 +INTERMEDIATE_SIZE = 128 +NUM_EXPERTS = 8 +NUM_EXPERTS_PER_TOK = 2 + + +def _make_inputs( + key: jax.Array, batch_size: int = BATCH_SIZE, sequence_length: int = SEQUENCE_LENGTH +) -> jax.Array: + return jax.random.normal( + key, (batch_size, sequence_length, HIDDEN_SIZE), dtype=DTYPE + ) + + +def _init_and_apply( + block, + inputs: jax.Array, + init_key: jax.Array, +) -> Tuple[dict, jax.Array, jax.Array]: + variables = block.init(init_key, inputs) + output, aux_loss = block.apply(variables, inputs) + return variables, output, aux_loss + + +# ----------------------------------------------------------------------------- +# Tests +# ----------------------------------------------------------------------------- + + +@pytest.mark.triton +class TestMoEBlockSingleDevice: + """Single-device smoke tests for :class:`MoEBlock`.""" + + @pytest.mark.parametrize("permutation_backend", ["pure_jax", "triton"]) + def test_forward_shape_and_finite(self, permutation_backend): + key = jax.random.PRNGKey(0) + init_key, data_key = jax.random.split(key) + + block = MoEBlock( + num_experts=NUM_EXPERTS, + num_experts_per_tok=NUM_EXPERTS_PER_TOK, + intermediate_size=INTERMEDIATE_SIZE, + permutation_backend=permutation_backend, + dtype=DTYPE, + ) + inputs = _make_inputs(data_key) + _variables, output, aux_loss = _init_and_apply(block, inputs, init_key) + + assert output.shape == inputs.shape, ( + f"Unexpected output shape {output.shape} for backend {permutation_backend}" + ) + assert output.dtype == inputs.dtype + assert jnp.all(jnp.isfinite(output)), "Output contains NaN/Inf" + assert aux_loss is None, "aux_loss should be None when aux_loss_coeff=0" + + @pytest.mark.parametrize("permutation_backend", ["pure_jax", "triton"]) + def test_backward_grad(self, permutation_backend): + key = jax.random.PRNGKey(1) + init_key, data_key = jax.random.split(key) + + block = MoEBlock( + num_experts=NUM_EXPERTS, + 
num_experts_per_tok=NUM_EXPERTS_PER_TOK, + intermediate_size=INTERMEDIATE_SIZE, + permutation_backend=permutation_backend, + dtype=DTYPE, + ) + inputs = _make_inputs(data_key) + variables = block.init(init_key, inputs) + + def loss_fn(variables, inputs): + output, _ = block.apply(variables, inputs) + return jnp.mean(output.astype(jnp.float32) ** 2) + + grads = jax.grad(loss_fn)(variables, inputs) + # All trainable kernels should receive a non-trivial gradient. + for name in ("gate_kernel", "wi_0", "wi_1", "wo"): + g = grads["params"][name] + assert jnp.all(jnp.isfinite(g)), f"{name} gradient has NaN/Inf" + assert jnp.any(g != 0.0), f"{name} gradient is identically zero" + + def test_pure_jax_triton_equivalence(self): + """Both permutation backends must produce the same forward + grads + under identical routing decisions. + + Since the two backends share the same routing path (TE's fused + top-k), fixing the gate kernel gives both the same routing decisions + and the remainder of the network is identical modulo the permutation + implementation, whose semantics are equivalent. + """ + key = jax.random.PRNGKey(2) + init_key, data_key = jax.random.split(key) + + base_kwargs = dict( + num_experts=NUM_EXPERTS, + num_experts_per_tok=NUM_EXPERTS_PER_TOK, + intermediate_size=INTERMEDIATE_SIZE, + dtype=DTYPE, + ) + pure_block = MoEBlock(permutation_backend="pure_jax", **base_kwargs) + triton_block = MoEBlock(permutation_backend="triton", **base_kwargs) + inputs = _make_inputs(data_key) + + # Share a single parameter tree so routing decisions and expert + # weights are identical for both backends. 
+ variables = pure_block.init(init_key, inputs) + + def loss_fn(block, variables, inputs): + output, _ = block.apply(variables, inputs) + return jnp.mean(output.astype(jnp.float32) ** 2), output + + (loss_pj, out_pj), grads_pj = jax.value_and_grad( + loss_fn, argnums=1, has_aux=True + )(pure_block, variables, inputs) + (loss_tr, out_tr), grads_tr = jax.value_and_grad( + loss_fn, argnums=1, has_aux=True + )(triton_block, variables, inputs) + + # BF16 tolerances: outputs come out of the grouped-GEMM + weighted + # sum so they accumulate error; we use ~2 ULPs worth of slack. + atol_out, rtol_out = 5e-2, 5e-2 + assert jnp.allclose(out_pj, out_tr, atol=atol_out, rtol=rtol_out), ( + f"Forward outputs differ across backends: max diff" + f" {jnp.max(jnp.abs(out_pj - out_tr))}" + ) + assert jnp.allclose(loss_pj, loss_tr, atol=atol_out, rtol=rtol_out) + + for name in ("gate_kernel", "wi_0", "wi_1", "wo"): + g_pj = grads_pj["params"][name] + g_tr = grads_tr["params"][name] + assert jnp.allclose(g_pj, g_tr, atol=1e-1, rtol=1e-1), ( + f"Gradient for {name} differs across backends: max diff" + f" {jnp.max(jnp.abs(g_pj - g_tr))}" + ) + + @pytest.mark.parametrize("permutation_backend", ["pure_jax", "triton"]) + def test_aux_loss_returned(self, permutation_backend): + key = jax.random.PRNGKey(3) + init_key, data_key = jax.random.split(key) + + block = MoEBlock( + num_experts=NUM_EXPERTS, + num_experts_per_tok=NUM_EXPERTS_PER_TOK, + intermediate_size=INTERMEDIATE_SIZE, + permutation_backend=permutation_backend, + aux_loss_coeff=1e-2, + dtype=DTYPE, + ) + inputs = _make_inputs(data_key) + _variables, output, aux_loss = _init_and_apply(block, inputs, init_key) + + assert output.shape == inputs.shape + assert aux_loss is not None, "aux_loss should be returned when coeff > 0" + assert aux_loss.shape == (), "aux_loss should be a scalar" + assert jnp.isfinite(aux_loss) + # With uniform-ish routing the loss should be small-positive, not huge. 
+ assert jnp.abs(aux_loss) < 1e2 + + @pytest.mark.parametrize("permutation_backend", ["pure_jax", "triton"]) + def test_group_topk_deepseek(self, permutation_backend): + """Exercise DeepSeek-style grouped top-k routing.""" + key = jax.random.PRNGKey(4) + init_key, data_key = jax.random.split(key) + + # num_groups must divide num_experts. + num_groups = 4 + group_topk = 2 + block = MoEBlock( + num_experts=NUM_EXPERTS, + num_experts_per_tok=NUM_EXPERTS_PER_TOK, + intermediate_size=INTERMEDIATE_SIZE, + permutation_backend=permutation_backend, + score_function="sigmoid", + num_groups=num_groups, + group_topk=group_topk, + dtype=DTYPE, + ) + inputs = _make_inputs(data_key) + _variables, output, _aux_loss = _init_and_apply(block, inputs, init_key) + + assert output.shape == inputs.shape + assert jnp.all(jnp.isfinite(output)) + + def test_align_size_equivalence_pure_jax(self): + """For the pure-JAX backend, ``align_size > 0`` must not change the + numerical output of the forward pass: padding tokens contribute zero + to every expert GEMM output (their input rows are zeros) and are + stripped before the weighted sum. 
+ """ + key = jax.random.PRNGKey(5) + init_key, data_key = jax.random.split(key) + + base_kwargs = dict( + num_experts=NUM_EXPERTS, + num_experts_per_tok=NUM_EXPERTS_PER_TOK, + intermediate_size=INTERMEDIATE_SIZE, + permutation_backend="pure_jax", + dtype=DTYPE, + ) + block_no_pad = MoEBlock(align_size=0, **base_kwargs) + block_pad = MoEBlock(align_size=16, **base_kwargs) + inputs = _make_inputs(data_key) + variables = block_no_pad.init(init_key, inputs) + + out_no_pad, _ = block_no_pad.apply(variables, inputs) + out_pad, _ = block_pad.apply(variables, inputs) + assert jnp.allclose(out_no_pad, out_pad, atol=5e-2, rtol=5e-2), ( + "align_size > 0 must not change pure_jax forward output; max diff" + f" {jnp.max(jnp.abs(out_no_pad - out_pad))}" + ) + + @pytest.mark.parametrize("permutation_backend", ["pure_jax", "triton"]) + def test_jit_and_determinism(self, permutation_backend): + """The block must be JIT-compilable and produce a deterministic + forward pass across repeat calls with the same params.""" + key = jax.random.PRNGKey(6) + init_key, data_key = jax.random.split(key) + + block = MoEBlock( + num_experts=NUM_EXPERTS, + num_experts_per_tok=NUM_EXPERTS_PER_TOK, + intermediate_size=INTERMEDIATE_SIZE, + permutation_backend=permutation_backend, + dtype=DTYPE, + ) + inputs = _make_inputs(data_key) + variables = block.init(init_key, inputs) + + @jax.jit + def forward(variables, inputs): + return block.apply(variables, inputs)[0] + + out_a = forward(variables, inputs) + out_b = forward(variables, inputs) + assert jnp.array_equal(out_a, out_b), "JITted forward is non-deterministic" diff --git a/transformer_engine/jax/flax/__init__.py b/transformer_engine/jax/flax/__init__.py index 92a968f061..0cd7835bcf 100644 --- a/transformer_engine/jax/flax/__init__.py +++ b/transformer_engine/jax/flax/__init__.py @@ -9,6 +9,7 @@ make_dot_general_cls, make_grouped_dense_cls, ) +from .moe import MoEBlock from .transformer import extend_logical_axis_rules from .transformer import 
DotProductAttention, MultiHeadAttention, RelativePositionBiases from .transformer import TransformerLayer, TransformerLayerType @@ -18,6 +19,7 @@ "LayerNorm", "LayerNormDenseGeneral", "LayerNormMLP", + "MoEBlock", "wrap_function_in_te_state_module", "make_dot_general_cls", "make_grouped_dense_cls", diff --git a/transformer_engine/jax/flax/moe.py b/transformer_engine/jax/flax/moe.py new file mode 100644 index 0000000000..ddbe687771 --- /dev/null +++ b/transformer_engine/jax/flax/moe.py @@ -0,0 +1,890 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Flax Linen MoEBlock for TransformerEngine JAX. + +This module exposes :class:`MoEBlock`, a self-contained Flax Linen MoE layer +that wires together TE's fused router, a selectable token-dispatch backend +(pure-JAX MaxText-style or Triton), TE's ``grouped_dense``, and optional +ring-of-experts Expert Parallelism. + +See ``plans/te_jax_moeblock_926b7994.plan.md`` for the full design rationale +and the mapping to Maxtext's ``RoutedMoE``. +""" + +from typing import Any, Callable, NewType, Optional, Tuple, Union + +import jax +import jax.numpy as jnp +from flax import linen as nn +from jax.sharding import PartitionSpec as P + +from ..dense import grouped_dense +from ..mt_permutation import mt_token_combine, mt_token_dispatch +from ..permutation import token_combine, token_dispatch +from ..quantize import noop_quantizer_set +from ..router import ScoreFunction, fused_moe_aux_loss, fused_topk_with_score_function +from ..sharding import with_sharding_constraint_by_logical_axes +from .module import TransformerEngineBase + +PRNGKey = Any +Shape = Tuple[int, ...] 
+DType = NewType("DType", jnp.dtype) +Array = NewType("Array", jnp.ndarray) +Initializer = Callable[[PRNGKey, Shape, DType], Array] + + +__all__ = ["MoEBlock"] + + +# ============================================================================= +# Helpers +# ============================================================================= + + +_ACTIVATIONS = { + "silu": jax.nn.silu, + "swish": jax.nn.silu, + "gelu": jax.nn.gelu, + "relu": jax.nn.relu, + "identity": lambda x: x, + "linear": lambda x: x, +} + + +def _get_activation_fn(name: str) -> Callable: + key = name.lower() + if key not in _ACTIVATIONS: + raise ValueError( + f"Unsupported activation_type={name!r}; supported: {sorted(_ACTIVATIONS)}" + ) + return _ACTIVATIONS[key] + + +def _extract_topk_from_routing_map( + sparse_probs: jnp.ndarray, + routing_map: jnp.ndarray, + topk: int, +) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Convert TE's ``(sparse_probs, routing_map)`` to ``(selected_experts, weights)``. + + ``routing_map`` is a boolean mask of shape ``[num_tokens, num_experts]`` + with exactly ``topk`` ``True`` positions per row. ``sparse_probs`` is the + same-shape float tensor whose non-zero entries are the routing weights. + + The per-token top-k expert IDs are recovered as the last ``topk`` indices + of ``argsort(routing_map)`` (``False < True``), and the corresponding + weights are gathered from ``sparse_probs`` along the expert axis. + + The within-row expert ordering does not have to match the router's + top-k ordering: :func:`mt_token_dispatch` and :func:`mt_token_combine` + only require that ``selected_experts`` and ``weights`` are consistent with + each other. + """ + # Cast to int32 so argsort has a well-defined ordering. (Ascending argsort + # on 0/1 puts the ``True`` positions last; we then slice the last ``topk``.) 
+ selected_experts = jnp.argsort(routing_map.astype(jnp.int32), axis=-1)[:, -topk:] + weights = jnp.take_along_axis(sparse_probs, selected_experts, axis=-1) + return selected_experts, weights + + +# ============================================================================= +# MoEBlock +# ============================================================================= + + +class MoEBlock(TransformerEngineBase): + """Mixture-of-Experts Flax Linen block. + + Encapsulates the full MoE forward pass: gate projection, fused top-k + routing, optional auxiliary load-balancing loss, token dispatch, per-expert + two-layer FFN via grouped GEMMs, activation, token combine, and optional + ring-of-experts expert parallelism. + + The permutation step is pluggable: the default ``permutation_backend="pure_jax"`` + uses the MaxText-style argsort-based dispatch/combine in + :mod:`transformer_engine.jax.mt_permutation`, which empirically outperforms + the Triton kernels on several E2E workloads. ``permutation_backend="triton"`` + uses TE's ``token_dispatch`` / ``token_combine`` kernels. + + Parameters + ---------- + num_experts : int + Total number of experts. + num_experts_per_tok : int + Top-k value (number of experts each token is routed to). + intermediate_size : int + Per-expert FFN hidden dim. + + activation_type : str + FFN activation applied to the gate projection. Paired with the up + projection in the SwiGLU-style ``act(wi_0) * wi_1`` product. Supported: + ``"silu"``/``"swish"`` (default), ``"gelu"``, ``"relu"``, + ``"identity"``/``"linear"``. + + score_function : str or ScoreFunction + ``"softmax"`` (default) or ``"sigmoid"`` for :func:`fused_topk_with_score_function`. + use_pre_softmax : bool + Apply softmax before top-k when ``score_function="softmax"``. + num_groups : int + Number of routing groups for grouped top-k (DeepSeek). ``<=0`` disables. + group_topk : int + Top-k at the group level. ``<=0`` disables. 
+ scaling_factor : float + Scaling factor applied to output probs. + use_expert_bias : bool + If ``True``, registers a learnable ``expert_bias`` parameter of shape + ``[num_experts]`` and passes it to the fused router. Only valid with + ``score_function="sigmoid"`` (DeepSeek V3 loss-free load balancing). + aux_loss_coeff : float + If ``> 0``, compute and return the MoE auxiliary load-balancing loss + scalar via :func:`fused_moe_aux_loss`. ``0`` disables. + + gate_kernel_axes : tuple[str, ...] + Logical partitioning axes for the gate kernel of shape + ``[hidden, num_experts]``. + wi_kernel_axes : tuple[str, ...] + Logical partitioning axes for the ``wi_0`` and ``wi_1`` kernels of + shape ``[num_experts, hidden, intermediate]``. Default: + ``("exp", "embed", "mlp")``. + wo_kernel_axes : tuple[str, ...] + Logical partitioning axes for the ``wo`` kernel of shape + ``[num_experts, intermediate, hidden]``. Default: + ``("exp", "mlp", "embed")``. + input_axes : tuple[str, ...] + Logical axes used to constrain the input activation sharding at the + block boundary. ``()`` (default) means no constraint. + + expert_parallelism_axis : Optional[str] + Mesh axis along which experts are split. When set, the forward pass + is wrapped in :func:`jax.experimental.shard_map.shard_map` that + implements the ring-of-experts EP strategy: ``all_gather`` on inputs + and gate logits, local routing + dispatch + FFN + combine, then + ``psum_scatter`` on the output. When ``None`` (default), no + ``shard_map`` wrapper is used; each primitive's ``custom_partitioning`` + rule handles DP/FSDP/TP automatically. + tensor_parallelism_axis : Optional[str] + Mesh axis for tensor parallelism on the FFN intermediate dim. When + set, the output of the ``wo`` grouped GEMM is ``psum_scatter`` ed + along this axis (inside the ``shard_map`` when EP is enabled, else at + the end of the forward pass). + + permutation_backend : str + ``"pure_jax"`` (default; faster on many E2E workloads) or ``"triton"``. 
+ align_size : int + Alignment for per-expert group sizes after padding. ``0`` disables + padding (faster for the unquantized path). ``>0`` is required for + quantized TE grouped GEMM whose recipe-specific alignment must divide + ``align_size``. Passed through to both permutation backends. + use_custom_sort_vjp : bool + Only used when ``permutation_backend="pure_jax"``. If ``True``, uses + a custom VJP for the argsort-based gather (faster in most cases). + + dtype : jnp.dtype + Compute and parameter dtype. + kernel_init : Initializer + Initializer for all kernels. Defaults to ``variance_scaling(1.0, + 'fan_in', 'truncated_normal')``. + use_bias : bool + If ``True``, registers per-expert FFN biases ``wi_0_bias``, + ``wi_1_bias``, ``wo_bias``. + """ + + # Architecture + num_experts: int = 8 + num_experts_per_tok: int = 2 + intermediate_size: int = 2048 + activation_type: str = "silu" + + # Routing + score_function: Union[str, ScoreFunction] = "softmax" + use_pre_softmax: bool = False + num_groups: int = -1 + group_topk: int = -1 + scaling_factor: float = 1.0 + use_expert_bias: bool = False + aux_loss_coeff: float = 0.0 + + # Sharding + gate_kernel_axes: Tuple[Optional[str], ...] = () + wi_kernel_axes: Tuple[Optional[str], ...] = ("exp", "embed", "mlp") + wo_kernel_axes: Tuple[Optional[str], ...] = ("exp", "mlp", "embed") + input_axes: Tuple[Optional[str], ...] = () + + # Parallelism + expert_parallelism_axis: Optional[str] = None + tensor_parallelism_axis: Optional[str] = None + # ``jax.sharding.Mesh`` to use when ``expert_parallelism_axis`` is set. + # Required for the ``shard_map`` wrapper; ignored otherwise. 
+ mesh: Optional[Any] = None + + # Permutation + permutation_backend: str = "pure_jax" + align_size: int = 0 + use_custom_sort_vjp: bool = True + + # Dtypes / init / misc + dtype: DType = jnp.float32 + kernel_init: Optional[Initializer] = None + bias_init: Initializer = nn.initializers.zeros + expert_bias_init: Initializer = nn.initializers.zeros + use_bias: bool = False + + def __post_init__(self): + if self.kernel_init is None: + object.__setattr__( + self, + "kernel_init", + nn.initializers.variance_scaling( + 1.0, "fan_in", "truncated_normal", dtype=self.dtype + ), + ) + if self.permutation_backend not in ("pure_jax", "triton"): + raise ValueError( + "permutation_backend must be 'pure_jax' or 'triton'," + f" got {self.permutation_backend!r}" + ) + if self.use_expert_bias: + # ``fused_topk_with_score_function`` only accepts ``expert_bias`` + # under the sigmoid score function. Raise early to surface the + # misconfiguration instead of failing deep inside the kernel. + score_func = ( + self.score_function.name.lower() + if isinstance(self.score_function, ScoreFunction) + else str(self.score_function).lower() + ) + if score_func != "sigmoid": + raise ValueError( + "use_expert_bias=True requires score_function='sigmoid';" + f" got {self.score_function!r}." 
+ ) + super().__post_init__() + + # ------------------------------------------------------------------ + # Parameter registration + # ------------------------------------------------------------------ + + def _make_params(self, hidden_size: int): + """Register module parameters and return them as a dict.""" + gate_kernel = self.param( + "gate_kernel", + nn.with_logical_partitioning(self.kernel_init, self.gate_kernel_axes), + (hidden_size, self.num_experts), + self.dtype, + ) + wi_0 = self.param( + "wi_0", + nn.with_logical_partitioning(self.kernel_init, self.wi_kernel_axes), + (self.num_experts, hidden_size, self.intermediate_size), + self.dtype, + ) + wi_1 = self.param( + "wi_1", + nn.with_logical_partitioning(self.kernel_init, self.wi_kernel_axes), + (self.num_experts, hidden_size, self.intermediate_size), + self.dtype, + ) + wo = self.param( + "wo", + nn.with_logical_partitioning(self.kernel_init, self.wo_kernel_axes), + (self.num_experts, self.intermediate_size, hidden_size), + self.dtype, + ) + params = { + "gate_kernel": gate_kernel, + "wi_0": wi_0, + "wi_1": wi_1, + "wo": wo, + } + if self.use_bias: + params["wi_0_bias"] = self.param( + "wi_0_bias", + nn.with_logical_partitioning(self.bias_init, ("exp", "mlp")), + (self.num_experts, self.intermediate_size), + self.dtype, + ) + params["wi_1_bias"] = self.param( + "wi_1_bias", + nn.with_logical_partitioning(self.bias_init, ("exp", "mlp")), + (self.num_experts, self.intermediate_size), + self.dtype, + ) + params["wo_bias"] = self.param( + "wo_bias", + nn.with_logical_partitioning(self.bias_init, ("exp", "embed")), + (self.num_experts, hidden_size), + self.dtype, + ) + if self.use_expert_bias: + params["expert_bias"] = self.param( + "expert_bias", + nn.with_logical_partitioning(self.expert_bias_init, ("exp",)), + (self.num_experts,), + self.dtype, + ) + return params + + # ------------------------------------------------------------------ + # Entry point + # 
------------------------------------------------------------------ + + @nn.compact + def __call__( + self, + inputs: Array, + deterministic: bool = True, + ) -> Tuple[Array, Optional[Array]]: + """Run the MoE forward pass. + + Parameters + ---------- + inputs : jnp.ndarray + Input tensor of shape ``[batch, sequence, hidden]``. + deterministic : bool + Reserved for future dropout-based routing; currently unused. + + Returns + ------- + output : jnp.ndarray + Output tensor of shape ``[batch, sequence, hidden]``. + aux_loss : Optional[jnp.ndarray] + Scalar auxiliary load-balancing loss when ``aux_loss_coeff > 0``, + else ``None``. + """ + del deterministic # unused for now + + assert inputs.ndim == 3, ( + f"MoEBlock expects [batch, sequence, hidden] input, got shape {inputs.shape}" + ) + inputs = with_sharding_constraint_by_logical_axes(inputs, self.input_axes) + + batch_size, sequence_length, hidden_size = inputs.shape + params = self._make_params(hidden_size) + + # Gate projection runs OUTSIDE the EP shard_map (mirroring Maxtext), + # so that each EP shard projects its own local slice of tokens and we + # later all-gather only the logits, not the full inputs. + gate_logits = self._gate(inputs, params["gate_kernel"]) + + if self.expert_parallelism_axis is not None: + return self._forward_ring_ep(inputs, gate_logits, params) + return self._forward_single_shard(inputs, gate_logits, params) + + # ------------------------------------------------------------------ + # Gate + # ------------------------------------------------------------------ + + def _gate(self, inputs: jnp.ndarray, gate_kernel: jnp.ndarray) -> jnp.ndarray: + """Linear gate projection ``inputs @ gate_kernel``. + + Kept as a plain matmul (not ``DenseGeneral``) so it integrates cleanly + with the EP shard_map below: the gate matmul runs in the outer + (pre-shard_map) scope and its output is all-gathered along the EP axis + inside the shard_map. 
+ """ + # Cast kernel to input dtype outside FP8 scope (gate is typically BF16/FP32). + kernel = gate_kernel.astype(inputs.dtype) + return jnp.einsum("bsh,he->bse", inputs, kernel) + + # ------------------------------------------------------------------ + # Single-shard (no EP) forward + # ------------------------------------------------------------------ + + def _forward_single_shard( + self, + inputs: jnp.ndarray, + gate_logits: jnp.ndarray, + params: dict, + ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: + batch_size, sequence_length, hidden_size = inputs.shape + + inputs_2d = inputs.reshape(-1, hidden_size) + logits_2d = gate_logits.reshape(-1, self.num_experts) + + sparse_probs, routing_map, aux_loss = self._route( + logits_2d, params.get("expert_bias") + ) + + expert_outputs, combine_state = self._dispatch_and_expert_ffn( + inputs_2d, + sparse_probs, + routing_map, + params, + num_experts_local=self.num_experts, + roll_to_expert_id=None, + local_tokens_per_expert_count=self.num_experts, + ) + + output = self._combine( + expert_outputs, + combine_state, + batch_size=batch_size, + sequence_length=sequence_length, + ) + + if self.tensor_parallelism_axis is not None: + output = jax.lax.psum_scatter( + output, + self.tensor_parallelism_axis, + scatter_dimension=2, + tiled=True, + ) + + return output, aux_loss + + # ------------------------------------------------------------------ + # Ring-of-Experts EP forward + # ------------------------------------------------------------------ + + def _forward_ring_ep( + self, + inputs: jnp.ndarray, + gate_logits: jnp.ndarray, + params: dict, + ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: + """Wrap the dispatch / FFN / combine pipeline in a ring-of-experts + ``shard_map``. + + Inside the shard_map each EP shard: + 1. ``all_gather`` s the inputs and logits along the EP axis so it + sees every token globally. + 2. 
Routes with ``roll_to_expert_id = num_experts_per_shard * shard_id`` + so its local experts are in slots ``[0, num_experts_per_shard)``. + 3. Dispatches tokens, slicing ``group_sizes`` to the first + ``num_experts_per_shard`` entries (the rest correspond to remote + experts and should be zero after the roll/mask). + 4. Runs the per-expert FFN on its local expert slice of + ``wi_0`` / ``wi_1`` / ``wo``. + 5. Combines at the expanded-batch shape ``[B * num_ep, S, H]`` then + ``psum_scatter`` s along the EP axis to return the local slice. + """ + from jax.experimental.shard_map import shard_map + + ep_axis = self.expert_parallelism_axis + if self.mesh is None: + raise ValueError( + "MoEBlock.expert_parallelism_axis is set; `mesh` must also be" + " provided so the ring-of-experts shard_map can be built." + ) + mesh = self.mesh + num_ep = mesh.shape[ep_axis] + assert self.num_experts % num_ep == 0, ( + f"num_experts={self.num_experts} must be divisible by EP size={num_ep}" + ) + num_experts_per_shard = self.num_experts // num_ep + + # in_specs / out_specs use PartitionSpec over the EP axis for inputs/ + # outputs (leading batch dim is split across EP) and ``P("exp", ...)`` + # for the expert weights, where we require the user's logical axis + # rules to map ``"exp"`` to the EP mesh axis. The expert bias is + # similarly sharded along the expert axis. 
+ inputs_spec = P(ep_axis, None, None) + logits_spec = P(ep_axis, None, None) + wi_spec = P(ep_axis, None, None) + wo_spec = P(ep_axis, None, None) + output_spec = P(ep_axis, None, None) + scalar_spec = P() + bias_1d_spec = P(ep_axis) + bias_2d_spec = P(ep_axis, None) + + expert_bias_value = params.get("expert_bias") + wi_0_bias_value = params.get("wi_0_bias") + wi_1_bias_value = params.get("wi_1_bias") + wo_bias_value = params.get("wo_bias") + + in_specs = [ + inputs_spec, + logits_spec, + wi_spec, + wi_spec, + wo_spec, + ] + captured = [ + inputs, + gate_logits, + params["wi_0"], + params["wi_1"], + params["wo"], + ] + if expert_bias_value is not None: + in_specs.append(bias_1d_spec) + captured.append(expert_bias_value) + if wi_0_bias_value is not None: + in_specs.extend([bias_2d_spec, bias_2d_spec, bias_2d_spec]) + captured.extend([wi_0_bias_value, wi_1_bias_value, wo_bias_value]) + + out_specs = (output_spec, scalar_spec) + + use_expert_bias = expert_bias_value is not None + use_bias = wi_0_bias_value is not None + + def _ring_fn(*args): + idx = 0 + local_inputs = args[idx]; idx += 1 + local_gate_logits = args[idx]; idx += 1 + local_wi_0 = args[idx]; idx += 1 + local_wi_1 = args[idx]; idx += 1 + local_wo = args[idx]; idx += 1 + local_expert_bias = None + if use_expert_bias: + local_expert_bias = args[idx]; idx += 1 + local_wi_0_bias = local_wi_1_bias = local_wo_bias = None + if use_bias: + local_wi_0_bias = args[idx]; idx += 1 + local_wi_1_bias = args[idx]; idx += 1 + local_wo_bias = args[idx]; idx += 1 + + shard_id = jax.lax.axis_index(ep_axis) + + # All-gather inputs and logits along the EP axis so each shard + # sees the global tokens. 
+ gathered_inputs = jax.lax.all_gather( + local_inputs, axis_name=ep_axis, tiled=True + ) + gathered_logits = jax.lax.all_gather( + local_gate_logits, axis_name=ep_axis, tiled=True + ) + + # If the user also sharded by EP on the expert_bias, ``local_expert_bias`` + # is already the local slice; the router operates over the full + # expert axis, so all-gather to reconstruct. + global_expert_bias = None + if local_expert_bias is not None: + global_expert_bias = jax.lax.all_gather( + local_expert_bias, axis_name=ep_axis, tiled=True + ) + + batch_size = gathered_inputs.shape[0] + sequence_length = gathered_inputs.shape[1] + hidden_size = gathered_inputs.shape[2] + + inputs_2d = gathered_inputs.reshape(-1, hidden_size) + logits_2d = gathered_logits.reshape(-1, self.num_experts) + + sparse_probs, routing_map, aux_loss = self._route( + logits_2d, global_expert_bias + ) + + # Ring-of-experts roll: after rolling expert columns by + # ``-num_experts_per_shard * shard_id``, this shard's experts + # occupy slots ``[0, num_experts_per_shard)`` in ``routing_map`` + # and ``sparse_probs``. + # + # For the Triton backend we additionally mask the remote-expert + # columns to False/0 so ``token_dispatch`` never writes those + # tokens into the local permuted buffer. For the pure-JAX backend + # we leave the routing_map untouched (mirroring Maxtext): the roll + # passed to ``mt_token_dispatch`` sorts remote-expert tokens past + # the local slots, and we later zero out those garbage rows of + # ``expert_outputs`` before the combine. 
+ roll = num_experts_per_shard * shard_id + routing_map = jnp.roll(routing_map, -roll, axis=-1) + sparse_probs = jnp.roll(sparse_probs, -roll, axis=-1) + if self.permutation_backend == "triton": + local_expert_mask = ( + jnp.arange(self.num_experts) < num_experts_per_shard + ) + routing_map = routing_map * local_expert_mask[None, :] + sparse_probs = sparse_probs * local_expert_mask[None, :].astype( + sparse_probs.dtype + ) + + # Build a reduced-expert view of the weights: the outer ``shard_map`` + # has already sliced the leading expert axis down to + # ``num_experts_per_shard`` per shard. Pass it through as-is to the + # dispatch / expert-FFN path with ``num_experts_local = num_experts_per_shard``. + local_params = { + "gate_kernel": None, # unused past gate + "wi_0": local_wi_0, + "wi_1": local_wi_1, + "wo": local_wo, + } + if use_bias: + local_params["wi_0_bias"] = local_wi_0_bias + local_params["wi_1_bias"] = local_wi_1_bias + local_params["wo_bias"] = local_wo_bias + + expert_outputs, combine_state = self._dispatch_and_expert_ffn( + inputs_2d, + sparse_probs, + routing_map, + local_params, + num_experts_local=num_experts_per_shard, + roll_to_expert_id=0, # roll is already applied on routing_map + local_tokens_per_expert_count=num_experts_per_shard, + ) + + # For the pure-JAX backend in ring-EP mode, zero out expert-output + # rows that correspond to remote experts (which ``grouped_dense`` + # leaves as garbage since ``group_sizes`` was truncated to the + # local slice). Without this, the unsort + weighted-sum in + # combine would mix garbage into every token's output. Matches + # ``moe.py:1731-1733`` in Maxtext. 
+ if self.permutation_backend == "pure_jax": + real_mask = ( + jnp.arange(expert_outputs.shape[0]) + < combine_state["local_real_size"] + ) + expert_outputs = jnp.where( + real_mask[:, None], expert_outputs, 0 + ) + + output = self._combine( + expert_outputs, + combine_state, + batch_size=batch_size, + sequence_length=sequence_length, + ) + + if self.tensor_parallelism_axis is not None: + output = jax.lax.psum_scatter( + output, + self.tensor_parallelism_axis, + scatter_dimension=2, + tiled=True, + ) + + # ``output`` is [B*num_ep, S, H] (global batch after all-gather); + # psum_scatter along EP returns the local [B, S, H] slice. + output = jax.lax.psum_scatter( + output, + ep_axis, + scatter_dimension=0, + tiled=True, + ) + + if aux_loss is None: + aux_loss = jnp.zeros((), dtype=self.dtype) + return output, aux_loss + + output, aux_loss = shard_map( + _ring_fn, + mesh=mesh, + in_specs=tuple(in_specs), + out_specs=out_specs, + check_rep=False, + )(*captured) + + if self.aux_loss_coeff <= 0.0: + aux_loss = None + return output, aux_loss + + # ------------------------------------------------------------------ + # Route + # ------------------------------------------------------------------ + + def _route( + self, + logits_2d: jnp.ndarray, + expert_bias: Optional[jnp.ndarray], + ) -> Tuple[jnp.ndarray, jnp.ndarray, Optional[jnp.ndarray]]: + """Run the fused router and optional aux-loss.""" + sparse_probs, routing_map = fused_topk_with_score_function( + logits_2d, + topk=self.num_experts_per_tok, + use_pre_softmax=self.use_pre_softmax, + num_groups=self.num_groups, + group_topk=self.group_topk, + scaling_factor=self.scaling_factor, + score_function=self.score_function, + expert_bias=expert_bias, + ) + sparse_probs = sparse_probs.astype(self.dtype) + + aux_loss = None + if self.aux_loss_coeff > 0.0: + # The score-for-aux kernel runs independently (no data dependency + # on the main kernel), so XLA can overlap them on the GPU. 
+ aux_scores, aux_routing_map = fused_topk_with_score_function( + logits_2d, + topk=self.num_experts_per_tok, + score_function=self.score_function, + compute_aux_scores=True, + ) + aux_tokens_per_expert = jnp.sum( + aux_routing_map.astype(jnp.int32), axis=0 + ) + aux_loss = fused_moe_aux_loss( + aux_scores, + aux_tokens_per_expert, + topk=self.num_experts_per_tok, + coeff=self.aux_loss_coeff, + ) + + return sparse_probs, routing_map, aux_loss + + # ------------------------------------------------------------------ + # Dispatch + expert FFN + # ------------------------------------------------------------------ + + def _dispatch_and_expert_ffn( + self, + inputs_2d: jnp.ndarray, + sparse_probs: jnp.ndarray, + routing_map: jnp.ndarray, + params: dict, + num_experts_local: int, + roll_to_expert_id: Optional[int], + local_tokens_per_expert_count: int, + ) -> Tuple[jnp.ndarray, dict]: + """Dispatch tokens, run the three grouped GEMMs + activation, return expert outputs. + + Returns a tuple ``(expert_outputs, combine_state)`` where + ``combine_state`` carries the per-backend state needed to rebuild the + original token ordering in :meth:`_combine`. + """ + num_tokens = inputs_2d.shape[0] + topk = self.num_experts_per_tok + + if self.permutation_backend == "pure_jax": + selected_experts, routing_weights = _extract_topk_from_routing_map( + sparse_probs, routing_map, topk + ) + sorted_inputs, perm_state, group_sizes = mt_token_dispatch( + inputs_2d, + selected_experts, + num_experts=self.num_experts, + num_experts_per_tok=topk, + align_size=self.align_size, + roll_to_expert_id=roll_to_expert_id, + use_custom_sort_vjp=self.use_custom_sort_vjp, + ) + # Slice group_sizes to just this shard's experts. When not using + # EP, ``num_experts_local == self.num_experts`` so this is a no-op. 
+ group_sizes = group_sizes[:local_tokens_per_expert_count] + # ``local_real_size = sum(group_sizes)`` is the number of permuted + # rows that actually correspond to tokens routed to this shard's + # experts. Used by the ring-EP caller to zero out garbage rows + # before combine. + combine_state = { + "backend": "pure_jax", + "perm_state": perm_state, + "routing_weights": routing_weights, + "local_real_size": jnp.sum(group_sizes), + } + else: # "triton" + num_out_tokens = num_tokens * topk + align_size_arg = self.align_size if self.align_size > 0 else None + ( + sorted_inputs, + _permuted_probs, + row_id_map, + pad_offsets, + group_sizes, + ) = token_dispatch( + inputs_2d, + routing_map, + num_out_tokens=num_out_tokens, + probs=sparse_probs, + align_size=align_size_arg, + ) + group_sizes = group_sizes[:local_tokens_per_expert_count] + combine_state = { + "backend": "triton", + "row_id_map": row_id_map, + "pad_offsets": pad_offsets, + "merging_probs": sparse_probs, + "group_sizes": group_sizes, + } + + # ------------------------------------------------------------------ + # Expert FFN: grouped GEMMs w0, w1 + activation + w_o. + # ------------------------------------------------------------------ + wi_0 = params["wi_0"] + wi_1 = params["wi_1"] + wo = params["wo"] + + # Each grouped_dense call gets its own quantizer_set with + # ``n_groups=num_experts_local``; this matches the shape of + # ``group_sizes`` passed in and keeps the quantizer FP8 meta correctly + # sized per shard. + q_set_w0 = self.generate_quantizer_set( + postfix="_w0", n_groups=num_experts_local + ) + q_set_w1 = self.generate_quantizer_set( + postfix="_w1", n_groups=num_experts_local + ) + q_set_wo = self.generate_quantizer_set( + postfix="_wo", n_groups=num_experts_local + ) + + # Cast kernels to the sort dtype when no FP8 quantization is active + # (mirrors DenseGeneral). 
+ if q_set_w0 == noop_quantizer_set: + wi_0 = wi_0.astype(sorted_inputs.dtype) + if q_set_w1 == noop_quantizer_set: + wi_1 = wi_1.astype(sorted_inputs.dtype) + if q_set_wo == noop_quantizer_set: + wo = wo.astype(sorted_inputs.dtype) + + # ``grouped_dense`` accepts per-expert bias of shape (G, N); it adds + # ``bias[i]`` to the ``group_sizes[i]`` rows belonging to expert ``i`` + # in the permuted layout. + wi_0_bias = params.get("wi_0_bias") if self.use_bias else None + wi_1_bias = params.get("wi_1_bias") if self.use_bias else None + wo_bias = params.get("wo_bias") if self.use_bias else None + + layer_w0 = grouped_dense( + sorted_inputs, + wi_0, + group_sizes, + contracting_dims=((1,), (1,)), + bias=wi_0_bias, + quantizer_set=q_set_w0, + ) + layer_w1 = grouped_dense( + sorted_inputs, + wi_1, + group_sizes, + contracting_dims=((1,), (1,)), + bias=wi_1_bias, + quantizer_set=q_set_w1, + ) + + act_fn = _get_activation_fn(self.activation_type) + intermediate = act_fn(layer_w0) * layer_w1 + + expert_outputs = grouped_dense( + intermediate, + wo, + group_sizes, + contracting_dims=((1,), (1,)), + bias=wo_bias, + quantizer_set=q_set_wo, + ) + + return expert_outputs, combine_state + + # ------------------------------------------------------------------ + # Combine + # ------------------------------------------------------------------ + + def _combine( + self, + expert_outputs: jnp.ndarray, + combine_state: dict, + batch_size: int, + sequence_length: int, + ) -> jnp.ndarray: + if combine_state["backend"] == "pure_jax": + return mt_token_combine( + expert_outputs, + combine_state["perm_state"], + combine_state["routing_weights"], + num_experts_per_tok=self.num_experts_per_tok, + batch_size=batch_size, + sequence_length=sequence_length, + use_custom_sort_vjp=self.use_custom_sort_vjp, + ) + # triton + out_2d = token_combine( + expert_outputs, + combine_state["row_id_map"], + merging_probs=combine_state["merging_probs"], + pad_offsets=combine_state["pad_offsets"], + ) + 
hidden_size = out_2d.shape[-1] + return out_2d.reshape(batch_size, sequence_length, hidden_size).astype( + self.dtype + ) diff --git a/transformer_engine/jax/mt_permutation.py b/transformer_engine/jax/mt_permutation.py new file mode 100644 index 0000000000..10882501ec --- /dev/null +++ b/transformer_engine/jax/mt_permutation.py @@ -0,0 +1,356 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Pure-JAX MoE Permutation API. + +This module provides a MaxText-style, pure-JAX implementation of MoE token +dispatch / combine as an alternative to the Triton-backed primitives in +``transformer_engine.jax.permutation``. Empirically this path has been faster +than the Triton kernels on several E2E workloads. + +The core design mirrors Maxtext's ``_mt_permute`` / ``_mt_unpermute`` in +``maxtext/src/maxtext/layers/moe.py``, with alignment-padding support ported +from `nvjax-svc-0/maxtext PR #36 `_ +so each expert's group size is a multiple of ``align_size`` (required for +quantized grouped GEMM whose recipe-specific alignment must divide +``align_size``). + +When ``align_size = 0`` padding is disabled (faster for the unquantized path); +when ``align_size > 0`` a static-size padding buffer of shape +``[num_experts * (align_size - 1)]`` is appended before the sort so the overall +shape is JIT-compatible. + +The public API is: + +* :func:`mt_token_dispatch` -- pure-JAX counterpart of ``token_dispatch``. +* :func:`mt_token_combine` -- pure-JAX counterpart of ``token_combine``. +* :class:`MTPermState` -- opaque state returned by ``mt_token_dispatch`` and + consumed by ``mt_token_combine``. 
+""" + +from typing import NamedTuple, Optional, Tuple + +import jax +import jax.numpy as jnp + +__all__ = [ + "MTPermState", + "mt_token_dispatch", + "mt_token_combine", +] + + +# ============================================================================= +# Custom-VJP argsort-based gather (``_sort_activations_custom``) +# ============================================================================= +# +# ``inputs[sort_indices]`` has a known inverse: ``output[argsort(sort_indices)]``. +# Using a custom VJP lets the backward pass exploit that inverse instead of +# relying on the compiler to discover it from the scatter-style default +# gradient of a gather, which is typically less efficient. + + +@jax.custom_vjp +def _sort_activations_custom(inputs: jax.Array, sort_indices: jax.Array) -> jax.Array: + """Sort ``inputs`` along the leading dim by ``sort_indices``.""" + return inputs[sort_indices, ...] + + +def _sort_activations_custom_fwd( + inputs: jax.Array, sort_indices: jax.Array +) -> Tuple[jax.Array, jax.Array]: + return _sort_activations_custom(inputs, sort_indices), sort_indices + + +def _sort_activations_custom_bwd( + residuals: jax.Array, grads: jax.Array +) -> Tuple[jax.Array, None]: + sort_indices = residuals + # Inverse permutation: gather-by-argsort undoes the forward gather. + return _sort_activations_custom(grads, jnp.argsort(sort_indices)), None + + +_sort_activations_custom.defvjp(_sort_activations_custom_fwd, _sort_activations_custom_bwd) + + +def _sort_activations( + inputs: jax.Array, + sort_indices: jax.Array, + use_custom_vjp: bool, +) -> jax.Array: + """Sort activations by ``sort_indices``, optionally with the custom VJP.""" + assert inputs.shape[0] == sort_indices.shape[0], ( + f"inputs.shape[0]={inputs.shape[0]} must match" + f" sort_indices.shape[0]={sort_indices.shape[0]}" + ) + with jax.named_scope("mt_sort_activations"): + if use_custom_vjp: + return _sort_activations_custom(inputs, sort_indices) + return inputs[sort_indices, ...] 
+ + +# ============================================================================= +# Permutation state carried from dispatch to combine +# ============================================================================= + + +class MTPermState(NamedTuple): + """Opaque state produced by :func:`mt_token_dispatch`. + + Attributes + ---------- + sorted_indices : jnp.ndarray + The argsort indices used in the forward sort. Needed to reverse the + permutation in :func:`mt_token_combine`. Shape + ``[num_real_tokens + padding_size]``. + num_real_tokens : int + Number of real (non-padding) permuted tokens, i.e. + ``batch_size * sequence_length * num_experts_per_tok``. Compile-time + constant. + padding_size : int + Number of alignment-padding tokens appended to the sort buffer. Equals + ``num_experts * (align_size - 1)`` when ``align_size > 0``, else ``0``. + Compile-time constant. + """ + + sorted_indices: jax.Array + num_real_tokens: int + padding_size: int + + +# ============================================================================= +# Dispatch (permute) +# ============================================================================= + + +def mt_token_dispatch( + inputs: jnp.ndarray, + selected_experts: jnp.ndarray, + num_experts: int, + num_experts_per_tok: int, + align_size: int = 0, + roll_to_expert_id: Optional[int] = None, + use_custom_sort_vjp: bool = True, +) -> Tuple[jnp.ndarray, MTPermState, jnp.ndarray]: + """Pure-JAX MaxText-style token dispatch. + + Parameters + ---------- + inputs : jnp.ndarray + Input tensor of shape ``[num_tokens, hidden_size]`` (or + ``[batch, seq, hidden]``; it will be flattened). + selected_experts : jnp.ndarray + Per-token expert IDs, shape ``[num_tokens, num_experts_per_tok]`` (or + ``[batch, seq, num_experts_per_tok]``). Integer dtype. + num_experts : int + Total number of experts. + num_experts_per_tok : int + Top-k. Must equal ``selected_experts.shape[-1]``. 
+ align_size : int, default 0 + Alignment for each expert's group size. ``0`` disables padding; a value + ``> 0`` appends a static-size padding buffer so each resulting group + size is a multiple of ``align_size``. + roll_to_expert_id : Optional[int] + If provided, rotates expert IDs by ``-roll_to_expert_id`` modulo + ``num_experts`` before the sort (ring-of-experts EP). The returned + ``group_sizes`` is rolled to match. + use_custom_sort_vjp : bool, default True + Whether to use the custom-VJP argsort gather for the sort. + + Returns + ------- + sorted_inputs : jnp.ndarray + Permuted tokens grouped by expert, shape + ``[num_real_tokens + padding_size, hidden_size]``. + perm_state : MTPermState + State needed by :func:`mt_token_combine`. + group_sizes : jnp.ndarray + Token count per expert, shape ``[num_experts]``. Each entry is a + multiple of ``align_size`` when ``align_size > 0``. + """ + assert num_experts_per_tok == selected_experts.shape[-1], ( + f"num_experts_per_tok={num_experts_per_tok} must match" + f" selected_experts.shape[-1]={selected_experts.shape[-1]}" + ) + assert align_size >= 0, f"align_size must be >= 0, got {align_size}" + + hidden_size = inputs.shape[-1] + # Flatten token dims. + inputs_2d = inputs.reshape(-1, hidden_size) + num_tokens = inputs_2d.shape[0] + num_real_tokens = num_tokens * num_experts_per_tok + + flatten_selected_experts = jnp.ravel(selected_experts) + + if align_size > 0: + # Per-expert token count, and how many extra tokens each expert needs + # to become aligned to ``align_size``. Using + # ``(align - count % align) % align`` gives 0 (not ``align``) when + # already aligned, so we never exceed the per-expert slot capacity of + # ``align_size - 1``. 
+ token_count_per_expert = jnp.bincount( + flatten_selected_experts, length=num_experts + ) + padding_tokens_required_per_expert = ( + (align_size - (token_count_per_expert % align_size)) % align_size + ) + + # Build a static-size padding buffer of shape + # ``[num_experts * (align_size - 1)]``. Each expert ``i`` owns a slot + # of ``align_size - 1`` positions (worst-case padding, which occurs + # when ``token_count[i] % align_size == 1``). Within slot ``i``, + # positions ``[0, padding_needed)`` are assigned expert ``i`` and act + # as real padding; the rest are assigned to ``num_experts - 1`` as + # overflow placeholders that keep the buffer statically sized for JIT. + max_padding_per_expert = align_size - 1 + max_total_padding_size = num_experts * max_padding_per_expert + positions = jnp.arange(max_total_padding_size) + expert_for_pos = positions // max_padding_per_expert + offset_in_slot = positions % max_padding_per_expert + padding_needed = padding_tokens_required_per_expert[expert_for_pos] + flatten_padding_selected_experts = jnp.where( + offset_in_slot < padding_needed, + expert_for_pos, + num_experts - 1, + ) + + flatten_selected_experts = jnp.concatenate( + [flatten_selected_experts, flatten_padding_selected_experts], axis=0 + ) + + if roll_to_expert_id is not None: + flatten_selected_experts = ( + flatten_selected_experts - roll_to_expert_id + ) % num_experts + + sorted_selected_experts = jnp.argsort(flatten_selected_experts) + + replicated_inputs_2d = jnp.repeat(inputs_2d, num_experts_per_tok, axis=0) + # Pad inputs with zeros so the sort operand shape matches the expanded + # selected-experts vector. 
+ replicated_inputs_2d = jnp.pad( + replicated_inputs_2d, + pad_width=((0, max_total_padding_size), (0, 0)), + mode="constant", + constant_values=0.0, + ) + + sorted_inputs = _sort_activations( + replicated_inputs_2d, sorted_selected_experts, use_custom_sort_vjp + ) + + # Compute ``group_sizes`` directly from counts rather than via + # ``bincount(flatten_selected_experts)``: the overflow placeholder + # tokens would inflate ``group_sizes[num_experts - 1]``, breaking the + # alignment guarantee. Direct computation gives each expert exactly + # ``ceil(count / align) * align`` tokens. + group_sizes = token_count_per_expert + padding_tokens_required_per_expert + + if roll_to_expert_id is not None: + group_sizes = jnp.roll(group_sizes, -roll_to_expert_id) + + padding_size = max_total_padding_size + else: + if roll_to_expert_id is not None: + flatten_selected_experts = ( + flatten_selected_experts - roll_to_expert_id + ) % num_experts + + sorted_selected_experts = jnp.argsort(flatten_selected_experts) + + replicated_inputs_2d = jnp.repeat(inputs_2d, num_experts_per_tok, axis=0) + sorted_inputs = _sort_activations( + replicated_inputs_2d, sorted_selected_experts, use_custom_sort_vjp + ) + + group_sizes = jnp.bincount(flatten_selected_experts, length=num_experts) + if roll_to_expert_id is not None: + group_sizes = jnp.roll(group_sizes, -roll_to_expert_id) + + padding_size = 0 + + perm_state = MTPermState( + sorted_indices=sorted_selected_experts, + num_real_tokens=num_real_tokens, + padding_size=padding_size, + ) + return sorted_inputs, perm_state, group_sizes + + +# ============================================================================= +# Combine (unpermute + weighted sum) +# ============================================================================= + + +def mt_token_combine( + expert_outputs: jnp.ndarray, + perm_state: MTPermState, + routing_weights: jnp.ndarray, + num_experts_per_tok: int, + batch_size: int, + sequence_length: int, + use_custom_sort_vjp: bool 
= True, +) -> jnp.ndarray: + """Pure-JAX MaxText-style token combine. + + Reverses the permutation performed by :func:`mt_token_dispatch`, strips + any alignment-padding rows appended during dispatch, and applies a + per-token weighted sum across the top-k experts. + + Parameters + ---------- + expert_outputs : jnp.ndarray + Output of the expert FFN, shape + ``[num_real_tokens + padding_size, hidden_size]``. + perm_state : MTPermState + State returned by :func:`mt_token_dispatch`. + routing_weights : jnp.ndarray + Top-k routing weights, shape ``[batch*seq, num_experts_per_tok]`` + (or broadcastable to it after a ``reshape``). + num_experts_per_tok : int + Top-k. + batch_size : int + Original batch size. + sequence_length : int + Original sequence length. + use_custom_sort_vjp : bool, default True + Whether to use the custom-VJP argsort gather for the unsort. + + Returns + ------- + output : jnp.ndarray + Combined output tensor of shape ``[batch_size, sequence_length, hidden_size]``. + """ + # Reverse the permutation: ``output[argsort(sorted_indices)]`` undoes + # ``input[sorted_indices]``. + unsort_intermediate = _sort_activations( + expert_outputs, + jnp.argsort(perm_state.sorted_indices), + use_custom_sort_vjp, + ) + + # Strip alignment padding tokens appended during dispatch. After unsorting, + # the first ``num_real_tokens`` rows hold the real per-(token, top-k) + # outputs; any trailing rows are padding placeholders (zeros) and must be + # discarded before the reshape below. 
+ if perm_state.padding_size > 0: + unsort_intermediate = unsort_intermediate[: perm_state.num_real_tokens] + + hidden_size = unsort_intermediate.shape[-1] + reshaped_weights = jnp.reshape(routing_weights, (-1, num_experts_per_tok)) + reshaped_intermediate = jnp.reshape( + unsort_intermediate, (reshaped_weights.shape[0], num_experts_per_tok, hidden_size) + ) + + # Cast weights to match intermediate dtype (weighted sum happens in + # intermediate dtype; callers can upcast before calling if higher + # precision weight-sum is desired). + reshaped_weights = reshaped_weights.astype(reshaped_intermediate.dtype) + with jax.named_scope("mt_weight_sum"): + output = jnp.einsum( + "BKE,BK -> BE", + reshaped_intermediate, + reshaped_weights, + ) + return output.reshape(batch_size, sequence_length, hidden_size) From aeb4fcf6e2700c1eef2788b7e1ec97b3ced6a25b Mon Sep 17 00:00:00 2001 From: tdophung Date: Tue, 21 Apr 2026 17:24:16 -0700 Subject: [PATCH 3/6] clean up any link to Maxtext. Permutation backends. clean up forward body single GPU vs. multi GPU Signed-off-by: tdophung --- transformer_engine/jax/flax/moe.py | 492 +++++++++-------------- transformer_engine/jax/mt_permutation.py | 356 ---------------- transformer_engine/jax/permutation.py | 336 +++++++++++++++- 3 files changed, 514 insertions(+), 670 deletions(-) delete mode 100644 transformer_engine/jax/mt_permutation.py diff --git a/transformer_engine/jax/flax/moe.py b/transformer_engine/jax/flax/moe.py index ddbe687771..6673ac1a71 100644 --- a/transformer_engine/jax/flax/moe.py +++ b/transformer_engine/jax/flax/moe.py @@ -6,11 +6,8 @@ This module exposes :class:`MoEBlock`, a self-contained Flax Linen MoE layer that wires together TE's fused router, a selectable token-dispatch backend -(pure-JAX MaxText-style or Triton), TE's ``grouped_dense``, and optional +(pure-JAX ``unfused_*`` or fused Triton), TE's ``grouped_dense``, and optional ring-of-experts Expert Parallelism. 
- -See ``plans/te_jax_moeblock_926b7994.plan.md`` for the full design rationale -and the mapping to Maxtext's ``RoutedMoE``. """ from typing import Any, Callable, NewType, Optional, Tuple, Union @@ -21,12 +18,17 @@ from jax.sharding import PartitionSpec as P from ..dense import grouped_dense -from ..mt_permutation import mt_token_combine, mt_token_dispatch -from ..permutation import token_combine, token_dispatch +from ..permutation import ( + _routing_map_to_selected_experts, + token_combine, + token_dispatch, + unfused_token_combine, + unfused_token_dispatch, +) from ..quantize import noop_quantizer_set from ..router import ScoreFunction, fused_moe_aux_loss, fused_topk_with_score_function from ..sharding import with_sharding_constraint_by_logical_axes -from .module import TransformerEngineBase +from .module import TransformerEngineBase, _convert_to_activation_function PRNGKey = Any Shape = Tuple[int, ...] @@ -38,57 +40,6 @@ __all__ = ["MoEBlock"] -# ============================================================================= -# Helpers -# ============================================================================= - - -_ACTIVATIONS = { - "silu": jax.nn.silu, - "swish": jax.nn.silu, - "gelu": jax.nn.gelu, - "relu": jax.nn.relu, - "identity": lambda x: x, - "linear": lambda x: x, -} - - -def _get_activation_fn(name: str) -> Callable: - key = name.lower() - if key not in _ACTIVATIONS: - raise ValueError( - f"Unsupported activation_type={name!r}; supported: {sorted(_ACTIVATIONS)}" - ) - return _ACTIVATIONS[key] - - -def _extract_topk_from_routing_map( - sparse_probs: jnp.ndarray, - routing_map: jnp.ndarray, - topk: int, -) -> Tuple[jnp.ndarray, jnp.ndarray]: - """Convert TE's ``(sparse_probs, routing_map)`` to ``(selected_experts, weights)``. - - ``routing_map`` is a boolean mask of shape ``[num_tokens, num_experts]`` - with exactly ``topk`` ``True`` positions per row. ``sparse_probs`` is the - same-shape float tensor whose non-zero entries are the routing weights. 
- - The per-token top-k expert IDs are recovered as the last ``topk`` indices - of ``argsort(routing_map)`` (``False < True``), and the corresponding - weights are gathered from ``sparse_probs`` along the expert axis. - - The within-row expert ordering does not have to match the router's - top-k ordering: :func:`mt_token_dispatch` and :func:`mt_token_combine` - only require that ``selected_experts`` and ``weights`` are consistent with - each other. - """ - # Cast to int32 so argsort has a well-defined ordering. (Ascending argsort - # on 0/1 puts the ``True`` positions last; we then slice the last ``topk``.) - selected_experts = jnp.argsort(routing_map.astype(jnp.int32), axis=-1)[:, -topk:] - weights = jnp.take_along_axis(sparse_probs, selected_experts, axis=-1) - return selected_experts, weights - - # ============================================================================= # MoEBlock # ============================================================================= @@ -102,11 +53,11 @@ class MoEBlock(TransformerEngineBase): two-layer FFN via grouped GEMMs, activation, token combine, and optional ring-of-experts expert parallelism. - The permutation step is pluggable: the default ``permutation_backend="pure_jax"`` - uses the MaxText-style argsort-based dispatch/combine in - :mod:`transformer_engine.jax.mt_permutation`, which empirically outperforms - the Triton kernels on several E2E workloads. ``permutation_backend="triton"`` - uses TE's ``token_dispatch`` / ``token_combine`` kernels. + The permutation step is pluggable via ``permutation_backend``: + ``"pure_jax"`` (default) uses the pure-JAX argsort-based + ``unfused_token_dispatch`` / ``unfused_token_combine`` in + :mod:`transformer_engine.jax.permutation`; ``"triton"`` uses TE's fused + ``token_dispatch`` / ``token_combine`` kernels. Parameters ---------- @@ -119,9 +70,9 @@ class MoEBlock(TransformerEngineBase): activation_type : str FFN activation applied to the gate projection. 
Paired with the up
-        projection in the SwiGLU-style ``act(wi_0) * wi_1`` product. Supported:
-        ``"silu"``/``"swish"`` (default), ``"gelu"``, ``"relu"``,
-        ``"identity"``/``"linear"``.
+        projection in the SwiGLU-style ``act(wi_0) * wi_1`` product. Resolved
+        via ``flax.linen`` activation lookup (``"silu"``, ``"gelu"``, ``"relu"``,
+        ``"swish"``, ...) plus ``"linear"`` for identity.
     score_function : str or ScoreFunction
         ``"softmax"`` (default) or ``"sigmoid"`` for
         :func:`fused_topk_with_score_function`.
@@ -135,8 +86,8 @@ class MoEBlock(TransformerEngineBase):
         Scaling factor applied to output probs.
     use_expert_bias : bool
         If ``True``, registers a learnable ``expert_bias`` parameter of shape
-        ``[num_experts]`` and passes it to the fused router. Only valid with
-        ``score_function="sigmoid"`` (DeepSeek V3 loss-free load balancing).
+        ``[num_experts]`` and passes it to the fused router. The router
+        primitive validates that this is paired with ``score_function="sigmoid"``.
     aux_loss_coeff : float
         If ``> 0``, compute and return the MoE auxiliary load-balancing loss
         scalar via :func:`fused_moe_aux_loss`. ``0`` disables.
@@ -171,21 +122,18 @@ class MoEBlock(TransformerEngineBase):
         the end of the forward pass).
     permutation_backend : str
-        ``"pure_jax"`` (default; faster on many E2E workloads) or ``"triton"``.
+        ``"pure_jax"`` (default) or ``"triton"``.
     align_size : int
         Alignment for per-expert group sizes after padding. ``0`` disables
         padding (faster for the unquantized path). ``>0`` is required for
         quantized TE grouped GEMM whose recipe-specific alignment must divide
-        ``align_size``. Passed through to both permutation backends.
-    use_custom_sort_vjp : bool
-        Only used when ``permutation_backend="pure_jax"``. If ``True``, uses
-        a custom VJP for the argsort-based gather (faster in most cases).
+        ``align_size``.
     dtype : jnp.dtype
         Compute and parameter dtype.
     kernel_init : Initializer
-        Initializer for all kernels. Defaults to ``variance_scaling(1.0,
-        'fan_in', 'truncated_normal')``.
+ Initializer for all kernels (gate + per-expert FFN). Defaults to + ``variance_scaling(1.0, 'fan_in', 'truncated_normal')`` (Flax convention). use_bias : bool If ``True``, registers per-expert FFN biases ``wi_0_bias``, ``wi_1_bias``, ``wo_bias``. @@ -222,7 +170,6 @@ class MoEBlock(TransformerEngineBase): # Permutation permutation_backend: str = "pure_jax" align_size: int = 0 - use_custom_sort_vjp: bool = True # Dtypes / init / misc dtype: DType = jnp.float32 @@ -245,20 +192,6 @@ def __post_init__(self): "permutation_backend must be 'pure_jax' or 'triton'," f" got {self.permutation_backend!r}" ) - if self.use_expert_bias: - # ``fused_topk_with_score_function`` only accepts ``expert_bias`` - # under the sigmoid score function. Raise early to surface the - # misconfiguration instead of failing deep inside the kernel. - score_func = ( - self.score_function.name.lower() - if isinstance(self.score_function, ScoreFunction) - else str(self.score_function).lower() - ) - if score_func != "sigmoid": - raise ValueError( - "use_expert_bias=True requires score_function='sigmoid';" - f" got {self.score_function!r}." - ) super().__post_init__() # ------------------------------------------------------------------ @@ -330,19 +263,13 @@ def _make_params(self, hidden_size: int): # ------------------------------------------------------------------ @nn.compact - def __call__( - self, - inputs: Array, - deterministic: bool = True, - ) -> Tuple[Array, Optional[Array]]: + def __call__(self, inputs: Array) -> Tuple[Array, Optional[Array]]: """Run the MoE forward pass. Parameters ---------- inputs : jnp.ndarray Input tensor of shape ``[batch, sequence, hidden]``. - deterministic : bool - Reserved for future dropout-based routing; currently unused. Returns ------- @@ -352,24 +279,39 @@ def __call__( Scalar auxiliary load-balancing loss when ``aux_loss_coeff > 0``, else ``None``. 
""" - del deterministic # unused for now - assert inputs.ndim == 3, ( f"MoEBlock expects [batch, sequence, hidden] input, got shape {inputs.shape}" ) inputs = with_sharding_constraint_by_logical_axes(inputs, self.input_axes) - batch_size, sequence_length, hidden_size = inputs.shape + _, _, hidden_size = inputs.shape params = self._make_params(hidden_size) - # Gate projection runs OUTSIDE the EP shard_map (mirroring Maxtext), - # so that each EP shard projects its own local slice of tokens and we - # later all-gather only the logits, not the full inputs. + # Gate runs OUTSIDE the EP shard_map below, so each EP shard projects + # its own local slice of tokens and we later all-gather only the + # smaller logits tensor instead of the full inputs. gate_logits = self._gate(inputs, params["gate_kernel"]) - if self.expert_parallelism_axis is not None: - return self._forward_ring_ep(inputs, gate_logits, params) - return self._forward_single_shard(inputs, gate_logits, params) + if self.expert_parallelism_axis is None: + # No EP: each primitive's own ``custom_partitioning`` rule handles + # DP / FSDP / TP across the mesh - no shard_map needed. + output, aux_loss = self._forward_body( + inputs, + gate_logits, + params, + num_experts_local=self.num_experts, + roll_to_expert_id=None, + ) + else: + # Ring-EP: ``_forward_body`` is wrapped in a shard_map that + # orchestrates the cross-primitive collectives (all_gather inputs + # / logits before, psum_scatter output after) which per-primitive + # ``custom_partitioning`` cannot express on its own. + output, aux_loss = self._forward_ring_ep(inputs, gate_logits, params) + + if self.aux_loss_coeff <= 0.0: + aux_loss = None + return output, aux_loss # ------------------------------------------------------------------ # Gate @@ -379,26 +321,34 @@ def _gate(self, inputs: jnp.ndarray, gate_kernel: jnp.ndarray) -> jnp.ndarray: """Linear gate projection ``inputs @ gate_kernel``. 
Kept as a plain matmul (not ``DenseGeneral``) so it integrates cleanly - with the EP shard_map below: the gate matmul runs in the outer - (pre-shard_map) scope and its output is all-gathered along the EP axis - inside the shard_map. + with the EP shard_map: the gate matmul runs in the outer (pre-shard_map) + scope and its output is all-gathered along the EP axis inside. """ # Cast kernel to input dtype outside FP8 scope (gate is typically BF16/FP32). kernel = gate_kernel.astype(inputs.dtype) return jnp.einsum("bsh,he->bse", inputs, kernel) # ------------------------------------------------------------------ - # Single-shard (no EP) forward + # Forward body (shared between no-EP and ring-EP paths) # ------------------------------------------------------------------ - def _forward_single_shard( + def _forward_body( self, inputs: jnp.ndarray, gate_logits: jnp.ndarray, params: dict, + num_experts_local: int, + roll_to_expert_id: Optional[int], ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: - batch_size, sequence_length, hidden_size = inputs.shape + """Routing + dispatch + per-expert FFN + combine. + Used both bare (no EP) and inside the ring-EP shard_map. In the + ring-EP case ``inputs`` and ``gate_logits`` are the post-all_gather + global tensors, ``num_experts_local == num_experts // num_ep``, and + ``roll_to_expert_id`` is the offset that brings this shard's experts + into slots ``[0, num_experts_local)``. + """ + batch_size, sequence_length, hidden_size = inputs.shape inputs_2d = inputs.reshape(-1, hidden_size) logits_2d = gate_logits.reshape(-1, self.num_experts) @@ -406,16 +356,48 @@ def _forward_single_shard( logits_2d, params.get("expert_bias") ) + if roll_to_expert_id is not None: + # Rotate expert columns so this shard's experts come first. 
+ routing_map = jnp.roll(routing_map, -roll_to_expert_id, axis=-1) + sparse_probs = jnp.roll(sparse_probs, -roll_to_expert_id, axis=-1) + if self.permutation_backend == "triton": + # Triton path: zero out remote-expert columns so the fused + # ``token_dispatch`` never writes tokens routed off-shard. + # The pure-JAX path zeroes garbage *output* rows below + # instead, since masking the routing_map directly would + # break the argsort-based permutation. + local_mask = ( + jnp.arange(self.num_experts) < num_experts_local + ) + routing_map = routing_map * local_mask + sparse_probs = sparse_probs * local_mask.astype(sparse_probs.dtype) + expert_outputs, combine_state = self._dispatch_and_expert_ffn( inputs_2d, sparse_probs, routing_map, params, - num_experts_local=self.num_experts, - roll_to_expert_id=None, - local_tokens_per_expert_count=self.num_experts, + num_experts_local=num_experts_local, + # The roll is already baked into ``routing_map``/``sparse_probs`` + # above, so the unfused dispatch must not roll again. + roll_to_expert_id=0 if roll_to_expert_id is not None else None, ) + if ( + roll_to_expert_id is not None + and self.permutation_backend == "pure_jax" + ): + # Zero the rows of ``expert_outputs`` past the real local-expert + # token count: ``grouped_dense`` leaves them as garbage because + # ``group_sizes`` was truncated to the local slice. Without this + # the unsort + weighted-sum in combine would mix garbage into + # every token's output (mirrors Maxtext's moe.py). 
+ real_mask = ( + jnp.arange(expert_outputs.shape[0]) + < combine_state["local_real_size"] + ) + expert_outputs = jnp.where(real_mask[:, None], expert_outputs, 0) + output = self._combine( expert_outputs, combine_state, @@ -434,7 +416,7 @@ def _forward_single_shard( return output, aux_loss # ------------------------------------------------------------------ - # Ring-of-Experts EP forward + # Ring-of-Experts EP wrapper # ------------------------------------------------------------------ def _forward_ring_ep( @@ -442,22 +424,16 @@ def _forward_ring_ep( inputs: jnp.ndarray, gate_logits: jnp.ndarray, params: dict, - ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: - """Wrap the dispatch / FFN / combine pipeline in a ring-of-experts - ``shard_map``. - - Inside the shard_map each EP shard: - 1. ``all_gather`` s the inputs and logits along the EP axis so it - sees every token globally. - 2. Routes with ``roll_to_expert_id = num_experts_per_shard * shard_id`` - so its local experts are in slots ``[0, num_experts_per_shard)``. - 3. Dispatches tokens, slicing ``group_sizes`` to the first - ``num_experts_per_shard`` entries (the rest correspond to remote - experts and should be zero after the roll/mask). - 4. Runs the per-expert FFN on its local expert slice of - ``wi_0`` / ``wi_1`` / ``wo``. - 5. Combines at the expanded-batch shape ``[B * num_ep, S, H]`` then - ``psum_scatter`` s along the EP axis to return the local slice. + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Wrap :meth:`_forward_body` in a ring-of-experts ``shard_map``. + + For each EP shard the wrapper: + 1. ``all_gather`` s the local inputs / logits / expert_bias along + the EP axis so the routing sees every token globally. + 2. Calls ``_forward_body`` with ``roll_to_expert_id = + num_experts_per_shard * shard_id`` and the EP-local weight slice. + 3. ``psum_scatter`` s the resulting ``[B*num_ep, S, H]`` output back + to the EP-sharded ``[B, S, H]`` layout. 
""" from jax.experimental.shard_map import shard_map @@ -474,201 +450,94 @@ def _forward_ring_ep( ) num_experts_per_shard = self.num_experts // num_ep - # in_specs / out_specs use PartitionSpec over the EP axis for inputs/ - # outputs (leading batch dim is split across EP) and ``P("exp", ...)`` - # for the expert weights, where we require the user's logical axis - # rules to map ``"exp"`` to the EP mesh axis. The expert bias is - # similarly sharded along the expert axis. - inputs_spec = P(ep_axis, None, None) - logits_spec = P(ep_axis, None, None) - wi_spec = P(ep_axis, None, None) - wo_spec = P(ep_axis, None, None) - output_spec = P(ep_axis, None, None) - scalar_spec = P() - bias_1d_spec = P(ep_axis) - bias_2d_spec = P(ep_axis, None) - - expert_bias_value = params.get("expert_bias") - wi_0_bias_value = params.get("wi_0_bias") - wi_1_bias_value = params.get("wi_1_bias") - wo_bias_value = params.get("wo_bias") - - in_specs = [ - inputs_spec, - logits_spec, - wi_spec, - wi_spec, - wo_spec, - ] - captured = [ - inputs, - gate_logits, - params["wi_0"], - params["wi_1"], - params["wo"], - ] - if expert_bias_value is not None: - in_specs.append(bias_1d_spec) - captured.append(expert_bias_value) - if wi_0_bias_value is not None: - in_specs.extend([bias_2d_spec, bias_2d_spec, bias_2d_spec]) - captured.extend([wi_0_bias_value, wi_1_bias_value, wo_bias_value]) - - out_specs = (output_spec, scalar_spec) - - use_expert_bias = expert_bias_value is not None - use_bias = wi_0_bias_value is not None - - def _ring_fn(*args): - idx = 0 - local_inputs = args[idx]; idx += 1 - local_gate_logits = args[idx]; idx += 1 - local_wi_0 = args[idx]; idx += 1 - local_wi_1 = args[idx]; idx += 1 - local_wo = args[idx]; idx += 1 - local_expert_bias = None - if use_expert_bias: - local_expert_bias = args[idx]; idx += 1 - local_wi_0_bias = local_wi_1_bias = local_wo_bias = None - if use_bias: - local_wi_0_bias = args[idx]; idx += 1 - local_wi_1_bias = args[idx]; idx += 1 - local_wo_bias = 
args[idx]; idx += 1 - + # Pack everything that crosses the shard_map boundary into a dict + # pytree. shard_map fully supports pytrees: ``in_specs`` must + # structurally match ``captured``, and we build them in lockstep so + # adding/removing an optional bias is a single ``dict[name] = ...``. + captured: dict = { + "inputs": inputs, + "gate_logits": gate_logits, + "wi_0": params["wi_0"], + "wi_1": params["wi_1"], + "wo": params["wo"], + } + in_specs: dict = { + "inputs": P(ep_axis, None, None), + "gate_logits": P(ep_axis, None, None), + "wi_0": P(ep_axis, None, None), + "wi_1": P(ep_axis, None, None), + "wo": P(ep_axis, None, None), + } + if "expert_bias" in params: + captured["expert_bias"] = params["expert_bias"] + in_specs["expert_bias"] = P(ep_axis) + if "wi_0_bias" in params: + for name in ("wi_0_bias", "wi_1_bias", "wo_bias"): + captured[name] = params[name] + in_specs[name] = P(ep_axis, None) + + def _ring_fn(local: dict) -> Tuple[jnp.ndarray, jnp.ndarray]: shard_id = jax.lax.axis_index(ep_axis) - # All-gather inputs and logits along the EP axis so each shard - # sees the global tokens. gathered_inputs = jax.lax.all_gather( - local_inputs, axis_name=ep_axis, tiled=True + local["inputs"], axis_name=ep_axis, tiled=True ) gathered_logits = jax.lax.all_gather( - local_gate_logits, axis_name=ep_axis, tiled=True - ) - - # If the user also sharded by EP on the expert_bias, ``local_expert_bias`` - # is already the local slice; the router operates over the full - # expert axis, so all-gather to reconstruct. 
- global_expert_bias = None - if local_expert_bias is not None: - global_expert_bias = jax.lax.all_gather( - local_expert_bias, axis_name=ep_axis, tiled=True - ) - - batch_size = gathered_inputs.shape[0] - sequence_length = gathered_inputs.shape[1] - hidden_size = gathered_inputs.shape[2] - - inputs_2d = gathered_inputs.reshape(-1, hidden_size) - logits_2d = gathered_logits.reshape(-1, self.num_experts) - - sparse_probs, routing_map, aux_loss = self._route( - logits_2d, global_expert_bias + local["gate_logits"], axis_name=ep_axis, tiled=True ) - # Ring-of-experts roll: after rolling expert columns by - # ``-num_experts_per_shard * shard_id``, this shard's experts - # occupy slots ``[0, num_experts_per_shard)`` in ``routing_map`` - # and ``sparse_probs``. - # - # For the Triton backend we additionally mask the remote-expert - # columns to False/0 so ``token_dispatch`` never writes those - # tokens into the local permuted buffer. For the pure-JAX backend - # we leave the routing_map untouched (mirroring Maxtext): the roll - # passed to ``mt_token_dispatch`` sorts remote-expert tokens past - # the local slots, and we later zero out those garbage rows of - # ``expert_outputs`` before the combine. - roll = num_experts_per_shard * shard_id - routing_map = jnp.roll(routing_map, -roll, axis=-1) - sparse_probs = jnp.roll(sparse_probs, -roll, axis=-1) - if self.permutation_backend == "triton": - local_expert_mask = ( - jnp.arange(self.num_experts) < num_experts_per_shard - ) - routing_map = routing_map * local_expert_mask[None, :] - sparse_probs = sparse_probs * local_expert_mask[None, :].astype( - sparse_probs.dtype - ) - - # Build a reduced-expert view of the weights: the outer ``shard_map`` - # has already sliced the leading expert axis down to - # ``num_experts_per_shard`` per shard. Pass it through as-is to the - # dispatch / expert-FFN path with ``num_experts_local = num_experts_per_shard``. 
- local_params = { - "gate_kernel": None, # unused past gate - "wi_0": local_wi_0, - "wi_1": local_wi_1, - "wo": local_wo, + local_params: dict = { + "wi_0": local["wi_0"], + "wi_1": local["wi_1"], + "wo": local["wo"], } - if use_bias: - local_params["wi_0_bias"] = local_wi_0_bias - local_params["wi_1_bias"] = local_wi_1_bias - local_params["wo_bias"] = local_wo_bias - - expert_outputs, combine_state = self._dispatch_and_expert_ffn( - inputs_2d, - sparse_probs, - routing_map, + if "expert_bias" in local: + # The router operates over the full expert axis, so the + # EP-sharded bias must be all-gathered. + local_params["expert_bias"] = jax.lax.all_gather( + local["expert_bias"], axis_name=ep_axis, tiled=True + ) + if "wi_0_bias" in local: + local_params["wi_0_bias"] = local["wi_0_bias"] + local_params["wi_1_bias"] = local["wi_1_bias"] + local_params["wo_bias"] = local["wo_bias"] + + output, aux_loss = self._forward_body( + gathered_inputs, + gathered_logits, local_params, num_experts_local=num_experts_per_shard, - roll_to_expert_id=0, # roll is already applied on routing_map - local_tokens_per_expert_count=num_experts_per_shard, - ) - - # For the pure-JAX backend in ring-EP mode, zero out expert-output - # rows that correspond to remote experts (which ``grouped_dense`` - # leaves as garbage since ``group_sizes`` was truncated to the - # local slice). Without this, the unsort + weighted-sum in - # combine would mix garbage into every token's output. Matches - # ``moe.py:1731-1733`` in Maxtext. 
- if self.permutation_backend == "pure_jax": - real_mask = ( - jnp.arange(expert_outputs.shape[0]) - < combine_state["local_real_size"] - ) - expert_outputs = jnp.where( - real_mask[:, None], expert_outputs, 0 - ) - - output = self._combine( - expert_outputs, - combine_state, - batch_size=batch_size, - sequence_length=sequence_length, + roll_to_expert_id=num_experts_per_shard * shard_id, ) - if self.tensor_parallelism_axis is not None: - output = jax.lax.psum_scatter( - output, - self.tensor_parallelism_axis, - scatter_dimension=2, - tiled=True, - ) - - # ``output`` is [B*num_ep, S, H] (global batch after all-gather); + # ``output`` is [B*num_ep, S, H] (global batch after all_gather); # psum_scatter along EP returns the local [B, S, H] slice. output = jax.lax.psum_scatter( - output, - ep_axis, - scatter_dimension=0, - tiled=True, + output, ep_axis, scatter_dimension=0, tiled=True ) + # ``out_specs`` must match the returned pytree structurally, so + # always emit a real scalar for aux_loss; the outer ``__call__`` + # re-strips it to None when ``aux_loss_coeff <= 0``. if aux_loss is None: aux_loss = jnp.zeros((), dtype=self.dtype) return output, aux_loss - output, aux_loss = shard_map( + # ``check_rep=False`` disables shard_map's invariant that any output + # declared as ``P()`` is replicated across ``ep_axis``. We use + # ``axis_index(ep_axis)`` inside ``_ring_fn`` to compute a per-shard + # roll, which makes the body genuinely non-replicated and would + # otherwise (correctly) fail the check. The ``psum_scatter`` of the + # output already produces the right cross-shard semantics; this is + # the standard JAX escape hatch when collectives + per-shard logic + # coexist. 
+ return shard_map( _ring_fn, mesh=mesh, - in_specs=tuple(in_specs), - out_specs=out_specs, + in_specs=in_specs, + out_specs=(P(ep_axis, None, None), P()), check_rep=False, - )(*captured) - - if self.aux_loss_coeff <= 0.0: - aux_loss = None - return output, aux_loss + )(captured) # ------------------------------------------------------------------ # Route @@ -726,7 +595,6 @@ def _dispatch_and_expert_ffn( params: dict, num_experts_local: int, roll_to_expert_id: Optional[int], - local_tokens_per_expert_count: int, ) -> Tuple[jnp.ndarray, dict]: """Dispatch tokens, run the three grouped GEMMs + activation, return expert outputs. @@ -738,21 +606,20 @@ def _dispatch_and_expert_ffn( topk = self.num_experts_per_tok if self.permutation_backend == "pure_jax": - selected_experts, routing_weights = _extract_topk_from_routing_map( + selected_experts, routing_weights = _routing_map_to_selected_experts( sparse_probs, routing_map, topk ) - sorted_inputs, perm_state, group_sizes = mt_token_dispatch( + sorted_inputs, perm_state, group_sizes = unfused_token_dispatch( inputs_2d, selected_experts, num_experts=self.num_experts, num_experts_per_tok=topk, align_size=self.align_size, roll_to_expert_id=roll_to_expert_id, - use_custom_sort_vjp=self.use_custom_sort_vjp, ) # Slice group_sizes to just this shard's experts. When not using # EP, ``num_experts_local == self.num_experts`` so this is a no-op. - group_sizes = group_sizes[:local_tokens_per_expert_count] + group_sizes = group_sizes[:num_experts_local] # ``local_real_size = sum(group_sizes)`` is the number of permuted # rows that actually correspond to tokens routed to this shard's # experts. 
Used by the ring-EP caller to zero out garbage rows @@ -779,7 +646,7 @@ def _dispatch_and_expert_ffn( probs=sparse_probs, align_size=align_size_arg, ) - group_sizes = group_sizes[:local_tokens_per_expert_count] + group_sizes = group_sizes[:num_experts_local] combine_state = { "backend": "triton", "row_id_map": row_id_map, @@ -842,7 +709,7 @@ def _dispatch_and_expert_ffn( quantizer_set=q_set_w1, ) - act_fn = _get_activation_fn(self.activation_type) + act_fn = _convert_to_activation_function(self.activation_type) intermediate = act_fn(layer_w0) * layer_w1 expert_outputs = grouped_dense( @@ -868,14 +735,13 @@ def _combine( sequence_length: int, ) -> jnp.ndarray: if combine_state["backend"] == "pure_jax": - return mt_token_combine( + return unfused_token_combine( expert_outputs, combine_state["perm_state"], combine_state["routing_weights"], num_experts_per_tok=self.num_experts_per_tok, batch_size=batch_size, sequence_length=sequence_length, - use_custom_sort_vjp=self.use_custom_sort_vjp, ) # triton out_2d = token_combine( diff --git a/transformer_engine/jax/mt_permutation.py b/transformer_engine/jax/mt_permutation.py deleted file mode 100644 index 10882501ec..0000000000 --- a/transformer_engine/jax/mt_permutation.py +++ /dev/null @@ -1,356 +0,0 @@ -# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -"""Pure-JAX MoE Permutation API. - -This module provides a MaxText-style, pure-JAX implementation of MoE token -dispatch / combine as an alternative to the Triton-backed primitives in -``transformer_engine.jax.permutation``. Empirically this path has been faster -than the Triton kernels on several E2E workloads. 
- -The core design mirrors Maxtext's ``_mt_permute`` / ``_mt_unpermute`` in -``maxtext/src/maxtext/layers/moe.py``, with alignment-padding support ported -from `nvjax-svc-0/maxtext PR #36 `_ -so each expert's group size is a multiple of ``align_size`` (required for -quantized grouped GEMM whose recipe-specific alignment must divide -``align_size``). - -When ``align_size = 0`` padding is disabled (faster for the unquantized path); -when ``align_size > 0`` a static-size padding buffer of shape -``[num_experts * (align_size - 1)]`` is appended before the sort so the overall -shape is JIT-compatible. - -The public API is: - -* :func:`mt_token_dispatch` -- pure-JAX counterpart of ``token_dispatch``. -* :func:`mt_token_combine` -- pure-JAX counterpart of ``token_combine``. -* :class:`MTPermState` -- opaque state returned by ``mt_token_dispatch`` and - consumed by ``mt_token_combine``. -""" - -from typing import NamedTuple, Optional, Tuple - -import jax -import jax.numpy as jnp - -__all__ = [ - "MTPermState", - "mt_token_dispatch", - "mt_token_combine", -] - - -# ============================================================================= -# Custom-VJP argsort-based gather (``_sort_activations_custom``) -# ============================================================================= -# -# ``inputs[sort_indices]`` has a known inverse: ``output[argsort(sort_indices)]``. -# Using a custom VJP lets the backward pass exploit that inverse instead of -# relying on the compiler to discover it from the scatter-style default -# gradient of a gather, which is typically less efficient. - - -@jax.custom_vjp -def _sort_activations_custom(inputs: jax.Array, sort_indices: jax.Array) -> jax.Array: - """Sort ``inputs`` along the leading dim by ``sort_indices``.""" - return inputs[sort_indices, ...] 
- - -def _sort_activations_custom_fwd( - inputs: jax.Array, sort_indices: jax.Array -) -> Tuple[jax.Array, jax.Array]: - return _sort_activations_custom(inputs, sort_indices), sort_indices - - -def _sort_activations_custom_bwd( - residuals: jax.Array, grads: jax.Array -) -> Tuple[jax.Array, None]: - sort_indices = residuals - # Inverse permutation: gather-by-argsort undoes the forward gather. - return _sort_activations_custom(grads, jnp.argsort(sort_indices)), None - - -_sort_activations_custom.defvjp(_sort_activations_custom_fwd, _sort_activations_custom_bwd) - - -def _sort_activations( - inputs: jax.Array, - sort_indices: jax.Array, - use_custom_vjp: bool, -) -> jax.Array: - """Sort activations by ``sort_indices``, optionally with the custom VJP.""" - assert inputs.shape[0] == sort_indices.shape[0], ( - f"inputs.shape[0]={inputs.shape[0]} must match" - f" sort_indices.shape[0]={sort_indices.shape[0]}" - ) - with jax.named_scope("mt_sort_activations"): - if use_custom_vjp: - return _sort_activations_custom(inputs, sort_indices) - return inputs[sort_indices, ...] - - -# ============================================================================= -# Permutation state carried from dispatch to combine -# ============================================================================= - - -class MTPermState(NamedTuple): - """Opaque state produced by :func:`mt_token_dispatch`. - - Attributes - ---------- - sorted_indices : jnp.ndarray - The argsort indices used in the forward sort. Needed to reverse the - permutation in :func:`mt_token_combine`. Shape - ``[num_real_tokens + padding_size]``. - num_real_tokens : int - Number of real (non-padding) permuted tokens, i.e. - ``batch_size * sequence_length * num_experts_per_tok``. Compile-time - constant. - padding_size : int - Number of alignment-padding tokens appended to the sort buffer. Equals - ``num_experts * (align_size - 1)`` when ``align_size > 0``, else ``0``. - Compile-time constant. 
- """ - - sorted_indices: jax.Array - num_real_tokens: int - padding_size: int - - -# ============================================================================= -# Dispatch (permute) -# ============================================================================= - - -def mt_token_dispatch( - inputs: jnp.ndarray, - selected_experts: jnp.ndarray, - num_experts: int, - num_experts_per_tok: int, - align_size: int = 0, - roll_to_expert_id: Optional[int] = None, - use_custom_sort_vjp: bool = True, -) -> Tuple[jnp.ndarray, MTPermState, jnp.ndarray]: - """Pure-JAX MaxText-style token dispatch. - - Parameters - ---------- - inputs : jnp.ndarray - Input tensor of shape ``[num_tokens, hidden_size]`` (or - ``[batch, seq, hidden]``; it will be flattened). - selected_experts : jnp.ndarray - Per-token expert IDs, shape ``[num_tokens, num_experts_per_tok]`` (or - ``[batch, seq, num_experts_per_tok]``). Integer dtype. - num_experts : int - Total number of experts. - num_experts_per_tok : int - Top-k. Must equal ``selected_experts.shape[-1]``. - align_size : int, default 0 - Alignment for each expert's group size. ``0`` disables padding; a value - ``> 0`` appends a static-size padding buffer so each resulting group - size is a multiple of ``align_size``. - roll_to_expert_id : Optional[int] - If provided, rotates expert IDs by ``-roll_to_expert_id`` modulo - ``num_experts`` before the sort (ring-of-experts EP). The returned - ``group_sizes`` is rolled to match. - use_custom_sort_vjp : bool, default True - Whether to use the custom-VJP argsort gather for the sort. - - Returns - ------- - sorted_inputs : jnp.ndarray - Permuted tokens grouped by expert, shape - ``[num_real_tokens + padding_size, hidden_size]``. - perm_state : MTPermState - State needed by :func:`mt_token_combine`. - group_sizes : jnp.ndarray - Token count per expert, shape ``[num_experts]``. Each entry is a - multiple of ``align_size`` when ``align_size > 0``. 
- """ - assert num_experts_per_tok == selected_experts.shape[-1], ( - f"num_experts_per_tok={num_experts_per_tok} must match" - f" selected_experts.shape[-1]={selected_experts.shape[-1]}" - ) - assert align_size >= 0, f"align_size must be >= 0, got {align_size}" - - hidden_size = inputs.shape[-1] - # Flatten token dims. - inputs_2d = inputs.reshape(-1, hidden_size) - num_tokens = inputs_2d.shape[0] - num_real_tokens = num_tokens * num_experts_per_tok - - flatten_selected_experts = jnp.ravel(selected_experts) - - if align_size > 0: - # Per-expert token count, and how many extra tokens each expert needs - # to become aligned to ``align_size``. Using - # ``(align - count % align) % align`` gives 0 (not ``align``) when - # already aligned, so we never exceed the per-expert slot capacity of - # ``align_size - 1``. - token_count_per_expert = jnp.bincount( - flatten_selected_experts, length=num_experts - ) - padding_tokens_required_per_expert = ( - (align_size - (token_count_per_expert % align_size)) % align_size - ) - - # Build a static-size padding buffer of shape - # ``[num_experts * (align_size - 1)]``. Each expert ``i`` owns a slot - # of ``align_size - 1`` positions (worst-case padding, which occurs - # when ``token_count[i] % align_size == 1``). Within slot ``i``, - # positions ``[0, padding_needed)`` are assigned expert ``i`` and act - # as real padding; the rest are assigned to ``num_experts - 1`` as - # overflow placeholders that keep the buffer statically sized for JIT. 
- max_padding_per_expert = align_size - 1 - max_total_padding_size = num_experts * max_padding_per_expert - positions = jnp.arange(max_total_padding_size) - expert_for_pos = positions // max_padding_per_expert - offset_in_slot = positions % max_padding_per_expert - padding_needed = padding_tokens_required_per_expert[expert_for_pos] - flatten_padding_selected_experts = jnp.where( - offset_in_slot < padding_needed, - expert_for_pos, - num_experts - 1, - ) - - flatten_selected_experts = jnp.concatenate( - [flatten_selected_experts, flatten_padding_selected_experts], axis=0 - ) - - if roll_to_expert_id is not None: - flatten_selected_experts = ( - flatten_selected_experts - roll_to_expert_id - ) % num_experts - - sorted_selected_experts = jnp.argsort(flatten_selected_experts) - - replicated_inputs_2d = jnp.repeat(inputs_2d, num_experts_per_tok, axis=0) - # Pad inputs with zeros so the sort operand shape matches the expanded - # selected-experts vector. - replicated_inputs_2d = jnp.pad( - replicated_inputs_2d, - pad_width=((0, max_total_padding_size), (0, 0)), - mode="constant", - constant_values=0.0, - ) - - sorted_inputs = _sort_activations( - replicated_inputs_2d, sorted_selected_experts, use_custom_sort_vjp - ) - - # Compute ``group_sizes`` directly from counts rather than via - # ``bincount(flatten_selected_experts)``: the overflow placeholder - # tokens would inflate ``group_sizes[num_experts - 1]``, breaking the - # alignment guarantee. Direct computation gives each expert exactly - # ``ceil(count / align) * align`` tokens. 
- group_sizes = token_count_per_expert + padding_tokens_required_per_expert - - if roll_to_expert_id is not None: - group_sizes = jnp.roll(group_sizes, -roll_to_expert_id) - - padding_size = max_total_padding_size - else: - if roll_to_expert_id is not None: - flatten_selected_experts = ( - flatten_selected_experts - roll_to_expert_id - ) % num_experts - - sorted_selected_experts = jnp.argsort(flatten_selected_experts) - - replicated_inputs_2d = jnp.repeat(inputs_2d, num_experts_per_tok, axis=0) - sorted_inputs = _sort_activations( - replicated_inputs_2d, sorted_selected_experts, use_custom_sort_vjp - ) - - group_sizes = jnp.bincount(flatten_selected_experts, length=num_experts) - if roll_to_expert_id is not None: - group_sizes = jnp.roll(group_sizes, -roll_to_expert_id) - - padding_size = 0 - - perm_state = MTPermState( - sorted_indices=sorted_selected_experts, - num_real_tokens=num_real_tokens, - padding_size=padding_size, - ) - return sorted_inputs, perm_state, group_sizes - - -# ============================================================================= -# Combine (unpermute + weighted sum) -# ============================================================================= - - -def mt_token_combine( - expert_outputs: jnp.ndarray, - perm_state: MTPermState, - routing_weights: jnp.ndarray, - num_experts_per_tok: int, - batch_size: int, - sequence_length: int, - use_custom_sort_vjp: bool = True, -) -> jnp.ndarray: - """Pure-JAX MaxText-style token combine. - - Reverses the permutation performed by :func:`mt_token_dispatch`, strips - any alignment-padding rows appended during dispatch, and applies a - per-token weighted sum across the top-k experts. - - Parameters - ---------- - expert_outputs : jnp.ndarray - Output of the expert FFN, shape - ``[num_real_tokens + padding_size, hidden_size]``. - perm_state : MTPermState - State returned by :func:`mt_token_dispatch`. 
- routing_weights : jnp.ndarray - Top-k routing weights, shape ``[batch*seq, num_experts_per_tok]`` - (or broadcastable to it after a ``reshape``). - num_experts_per_tok : int - Top-k. - batch_size : int - Original batch size. - sequence_length : int - Original sequence length. - use_custom_sort_vjp : bool, default True - Whether to use the custom-VJP argsort gather for the unsort. - - Returns - ------- - output : jnp.ndarray - Combined output tensor of shape ``[batch_size, sequence_length, hidden_size]``. - """ - # Reverse the permutation: ``output[argsort(sorted_indices)]`` undoes - # ``input[sorted_indices]``. - unsort_intermediate = _sort_activations( - expert_outputs, - jnp.argsort(perm_state.sorted_indices), - use_custom_sort_vjp, - ) - - # Strip alignment padding tokens appended during dispatch. After unsorting, - # the first ``num_real_tokens`` rows hold the real per-(token, top-k) - # outputs; any trailing rows are padding placeholders (zeros) and must be - # discarded before the reshape below. - if perm_state.padding_size > 0: - unsort_intermediate = unsort_intermediate[: perm_state.num_real_tokens] - - hidden_size = unsort_intermediate.shape[-1] - reshaped_weights = jnp.reshape(routing_weights, (-1, num_experts_per_tok)) - reshaped_intermediate = jnp.reshape( - unsort_intermediate, (reshaped_weights.shape[0], num_experts_per_tok, hidden_size) - ) - - # Cast weights to match intermediate dtype (weighted sum happens in - # intermediate dtype; callers can upcast before calling if higher - # precision weight-sum is desired). 
- reshaped_weights = reshaped_weights.astype(reshaped_intermediate.dtype) - with jax.named_scope("mt_weight_sum"): - output = jnp.einsum( - "BKE,BK -> BE", - reshaped_intermediate, - reshaped_weights, - ) - return output.reshape(batch_size, sequence_length, hidden_size) diff --git a/transformer_engine/jax/permutation.py b/transformer_engine/jax/permutation.py index 81972aac0f..1a492ba186 100644 --- a/transformer_engine/jax/permutation.py +++ b/transformer_engine/jax/permutation.py @@ -7,6 +7,17 @@ This module provides high-level token dispatch and combine operations for Mixture of Experts (MoE) models with proper automatic differentiation support. +Two backends are offered: + +* Fused, Triton-backed ``token_dispatch`` / ``token_combine`` - uses the + Triton kernels in ``transformer_engine.jax.triton_extensions.permutation``. +* Unfused, pure-JAX ``unfused_token_dispatch`` / ``unfused_token_combine`` - + uses only ``jnp.argsort`` + gather and is therefore compiled as plain XLA. + +Both backends support optional alignment padding (``align_size > 0``) so each +expert's group size is a multiple of ``align_size``, which is required for +quantized grouped GEMMs. 
+ Token Dispatch (Permute): - Forward: Permute tokens according to routing map (scatter to experts) - Backward: Unpermute gradients (gather from experts) @@ -17,7 +28,7 @@ """ from functools import partial -from typing import Optional, Tuple +from typing import NamedTuple, Optional, Tuple import jax import jax.numpy as jnp @@ -38,6 +49,9 @@ "token_dispatch", "token_combine", "sort_chunks_by_index", + "unfused_token_dispatch", + "unfused_token_combine", + "UnfusedPermState", ] @@ -655,3 +669,323 @@ def _sort_chunks_by_index_bwd_rule( _sort_chunks_by_index.defvjp(_sort_chunks_by_index_fwd_rule, _sort_chunks_by_index_bwd_rule) + + +# ============================================================================= +# Unfused (pure-JAX) token dispatch / combine +# ============================================================================= +# +# The following implementations use only ``jnp.argsort`` + gather and compile +# to plain XLA. They are a drop-in alternative to ``token_dispatch`` / +# ``token_combine`` above, differing only in input/output conventions (the +# fused path takes ``routing_map`` and ``sparse_probs`` over all experts; the +# unfused path takes dense ``selected_experts`` and per-token ``weights`` of +# shape ``[..., topk]``). + + +# ----------------------------------------------------------------------------- +# Custom-VJP argsort-based gather. +# +# ``inputs[sort_indices]`` has a known inverse: ``output[argsort(sort_indices)]``. +# Using a custom VJP lets the backward pass exploit that inverse instead of +# relying on the compiler to discover it from the scatter-style default +# gradient of a gather, which is typically less efficient. 
+ + +@jax.custom_vjp +def _sort_activations(inputs: jax.Array, sort_indices: jax.Array) -> jax.Array: + """Sort ``inputs`` along the leading dim by ``sort_indices``.""" + assert inputs.shape[0] == sort_indices.shape[0], ( + f"inputs.shape[0]={inputs.shape[0]} must match" + f" sort_indices.shape[0]={sort_indices.shape[0]}" + ) + with jax.named_scope("unfused_sort_activations"): + return inputs[sort_indices, ...] + + +def _sort_activations_fwd( + inputs: jax.Array, sort_indices: jax.Array +) -> Tuple[jax.Array, jax.Array]: + return _sort_activations(inputs, sort_indices), sort_indices + + +def _sort_activations_bwd( + residuals: jax.Array, grads: jax.Array +) -> Tuple[jax.Array, None]: + sort_indices = residuals + # Inverse permutation: gather-by-argsort undoes the forward gather. + return _sort_activations(grads, jnp.argsort(sort_indices)), None + + +_sort_activations.defvjp(_sort_activations_fwd, _sort_activations_bwd) + + +def _routing_map_to_selected_experts( + sparse_probs: jnp.ndarray, + routing_map: jnp.ndarray, + topk: int, +) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Convert ``(sparse_probs, routing_map)`` from TE's fused router to the + ``(selected_experts, weights)`` format consumed by + :func:`unfused_token_dispatch`. + + ``routing_map`` is a boolean mask of shape ``[num_tokens, num_experts]`` + with exactly ``topk`` ``True`` positions per row. + """ + # Argsort on a bool tensor places ``True`` rows last (False=0 < True=1), + # so the last ``topk`` indices are the selected expert IDs. + selected_experts = jnp.argsort(routing_map, axis=-1)[..., -topk:] + weights = jnp.take_along_axis(sparse_probs, selected_experts, axis=-1) + return selected_experts, weights + + +# ----------------------------------------------------------------------------- +# Permutation state carried from dispatch to combine. + + +class UnfusedPermState(NamedTuple): + """Opaque state produced by :func:`unfused_token_dispatch`. 
+ + Attributes + ---------- + sorted_indices : jnp.ndarray + The argsort indices used in the forward sort. Needed to reverse the + permutation in :func:`unfused_token_combine`. Shape + ``[num_real_tokens + padding_size]``. + num_real_tokens : int + Number of real (non-padding) permuted tokens, i.e. + ``batch_size * sequence_length * num_experts_per_tok``. Compile-time + constant. + padding_size : int + Number of alignment-padding tokens appended to the sort buffer. Equals + ``num_experts * (align_size - 1)`` when ``align_size > 0``, else ``0``. + Compile-time constant. + """ + + sorted_indices: jax.Array + num_real_tokens: int + padding_size: int + + +# ----------------------------------------------------------------------------- +# Dispatch (permute) + + +def unfused_token_dispatch( + inputs: jnp.ndarray, + selected_experts: jnp.ndarray, + num_experts: int, + num_experts_per_tok: int, + align_size: int = 0, + roll_to_expert_id: Optional[int] = None, +) -> Tuple[jnp.ndarray, UnfusedPermState, jnp.ndarray]: + """Pure-JAX ``argsort``-based token dispatch. + + Parameters + ---------- + inputs : jnp.ndarray + Input tensor of shape ``[num_tokens, hidden_size]`` (or + ``[batch, seq, hidden]``; it will be flattened). + selected_experts : jnp.ndarray + Per-token expert IDs, shape ``[num_tokens, num_experts_per_tok]`` (or + ``[batch, seq, num_experts_per_tok]``). Integer dtype. + num_experts : int + Total number of experts. + num_experts_per_tok : int + Top-k. Must equal ``selected_experts.shape[-1]``. + align_size : int, default 0 + Alignment for each expert's group size. ``0`` disables padding; a value + ``> 0`` appends a static-size padding buffer so each resulting group + size is a multiple of ``align_size`` (required for quantized grouped + GEMM). + roll_to_expert_id : Optional[int] + If provided, rotates expert IDs by ``-roll_to_expert_id`` modulo + ``num_experts`` before the sort (ring-of-experts EP). The returned + ``group_sizes`` is rolled to match. 
+ + Returns + ------- + sorted_inputs : jnp.ndarray + Permuted tokens grouped by expert, shape + ``[num_real_tokens + padding_size, hidden_size]``. + perm_state : UnfusedPermState + State needed by :func:`unfused_token_combine`. + group_sizes : jnp.ndarray + Token count per expert, shape ``[num_experts]``. Each entry is a + multiple of ``align_size`` when ``align_size > 0``. + """ + assert num_experts_per_tok == selected_experts.shape[-1], ( + f"num_experts_per_tok={num_experts_per_tok} must match" + f" selected_experts.shape[-1]={selected_experts.shape[-1]}" + ) + assert align_size >= 0, f"align_size must be >= 0, got {align_size}" + + hidden_size = inputs.shape[-1] + inputs_2d = inputs.reshape(-1, hidden_size) + num_tokens = inputs_2d.shape[0] + num_real_tokens = num_tokens * num_experts_per_tok + + flatten_selected_experts = jnp.ravel(selected_experts) + + if align_size > 0: + # Per-expert token count, and how many extra tokens each expert needs + # to become aligned to ``align_size``. Using + # ``(align - count % align) % align`` gives 0 (not ``align``) when + # already aligned, so we never exceed the per-expert slot capacity of + # ``align_size - 1``. + token_count_per_expert = jnp.bincount( + flatten_selected_experts, length=num_experts + ) + padding_tokens_required_per_expert = ( + (align_size - (token_count_per_expert % align_size)) % align_size + ) + + # Build a static-size padding buffer of shape + # ``[num_experts * (align_size - 1)]``. Each expert ``i`` owns a slot + # of ``align_size - 1`` positions (worst-case padding, which occurs + # when ``token_count[i] % align_size == 1``). Within slot ``i``, + # positions ``[0, padding_needed)`` are assigned expert ``i`` and act + # as real padding; the rest are assigned to ``num_experts - 1`` as + # overflow placeholders that keep the buffer statically sized for JIT. 
+ max_padding_per_expert = align_size - 1 + max_total_padding_size = num_experts * max_padding_per_expert + positions = jnp.arange(max_total_padding_size) + expert_for_pos = positions // max_padding_per_expert + offset_in_slot = positions % max_padding_per_expert + padding_needed = padding_tokens_required_per_expert[expert_for_pos] + flatten_padding_selected_experts = jnp.where( + offset_in_slot < padding_needed, + expert_for_pos, + num_experts - 1, + ) + + flatten_selected_experts = jnp.concatenate( + [flatten_selected_experts, flatten_padding_selected_experts], axis=0 + ) + + if roll_to_expert_id is not None: + flatten_selected_experts = ( + flatten_selected_experts - roll_to_expert_id + ) % num_experts + + sorted_selected_experts = jnp.argsort(flatten_selected_experts) + + replicated_inputs_2d = jnp.repeat(inputs_2d, num_experts_per_tok, axis=0) + # Pad inputs with zeros so the sort operand shape matches the expanded + # selected-experts vector. + replicated_inputs_2d = jnp.pad( + replicated_inputs_2d, + pad_width=((0, max_total_padding_size), (0, 0)), + mode="constant", + constant_values=0.0, + ) + + sorted_inputs = _sort_activations(replicated_inputs_2d, sorted_selected_experts) + + # Compute ``group_sizes`` directly from counts rather than via + # ``bincount(flatten_selected_experts)``: the overflow placeholder + # tokens would inflate ``group_sizes[num_experts - 1]``, breaking the + # alignment guarantee. Direct computation gives each expert exactly + # ``ceil(count / align) * align`` tokens. 
+ group_sizes = token_count_per_expert + padding_tokens_required_per_expert + + if roll_to_expert_id is not None: + group_sizes = jnp.roll(group_sizes, -roll_to_expert_id) + + padding_size = max_total_padding_size + else: + if roll_to_expert_id is not None: + flatten_selected_experts = ( + flatten_selected_experts - roll_to_expert_id + ) % num_experts + + sorted_selected_experts = jnp.argsort(flatten_selected_experts) + + replicated_inputs_2d = jnp.repeat(inputs_2d, num_experts_per_tok, axis=0) + sorted_inputs = _sort_activations(replicated_inputs_2d, sorted_selected_experts) + + group_sizes = jnp.bincount(flatten_selected_experts, length=num_experts) + if roll_to_expert_id is not None: + group_sizes = jnp.roll(group_sizes, -roll_to_expert_id) + + padding_size = 0 + + perm_state = UnfusedPermState( + sorted_indices=sorted_selected_experts, + num_real_tokens=num_real_tokens, + padding_size=padding_size, + ) + return sorted_inputs, perm_state, group_sizes + + +# ----------------------------------------------------------------------------- +# Combine (unpermute + weighted sum) + + +def unfused_token_combine( + expert_outputs: jnp.ndarray, + perm_state: UnfusedPermState, + routing_weights: jnp.ndarray, + num_experts_per_tok: int, + batch_size: int, + sequence_length: int, +) -> jnp.ndarray: + """Pure-JAX ``argsort``-based token combine. + + Reverses the permutation performed by :func:`unfused_token_dispatch`, + strips any alignment-padding rows appended during dispatch, and applies a + per-token weighted sum across the top-k experts. + + Parameters + ---------- + expert_outputs : jnp.ndarray + Output of the expert FFN, shape + ``[num_real_tokens + padding_size, hidden_size]``. + perm_state : UnfusedPermState + State returned by :func:`unfused_token_dispatch`. + routing_weights : jnp.ndarray + Top-k routing weights, shape ``[batch*seq, num_experts_per_tok]`` + (or broadcastable to it after a ``reshape``). + num_experts_per_tok : int + Top-k. 
+ batch_size : int + Original batch size. + sequence_length : int + Original sequence length. + + Returns + ------- + output : jnp.ndarray + Combined output tensor of shape ``[batch_size, sequence_length, hidden_size]``. + """ + # Reverse the permutation: ``output[argsort(sorted_indices)]`` undoes + # ``input[sorted_indices]``. + unsort_intermediate = _sort_activations( + expert_outputs, + jnp.argsort(perm_state.sorted_indices), + ) + + # Strip alignment padding tokens appended during dispatch. After unsorting, + # the first ``num_real_tokens`` rows hold the real per-(token, top-k) + # outputs; any trailing rows are padding placeholders (zeros) and must be + # discarded before the reshape below. + if perm_state.padding_size > 0: + unsort_intermediate = unsort_intermediate[: perm_state.num_real_tokens] + + hidden_size = unsort_intermediate.shape[-1] + reshaped_weights = jnp.reshape(routing_weights, (-1, num_experts_per_tok)) + reshaped_intermediate = jnp.reshape( + unsort_intermediate, (reshaped_weights.shape[0], num_experts_per_tok, hidden_size) + ) + + # Cast weights to match intermediate dtype (weighted sum happens in + # intermediate dtype; callers can upcast before calling if higher + # precision weight-sum is desired). + reshaped_weights = reshaped_weights.astype(reshaped_intermediate.dtype) + with jax.named_scope("unfused_weight_sum"): + output = jnp.einsum( + "BKE,BK -> BE", + reshaped_intermediate, + reshaped_weights, + ) + return output.reshape(batch_size, sequence_length, hidden_size) From 1b5c36aa87a07a9c63f4ca7b902f7aaf37e60c87 Mon Sep 17 00:00:00 2001 From: tdophung Date: Wed, 22 Apr 2026 17:58:31 -0700 Subject: [PATCH 4/6] add distributed test. 
Signed-off-by: tdophung --- tests/jax/test_distributed_moe_block.py | 143 ++++++++++++++++++++++++ tests/jax/test_moe_block.py | 23 +++- transformer_engine/jax/flax/moe.py | 24 ++-- 3 files changed, 180 insertions(+), 10 deletions(-) create mode 100644 tests/jax/test_distributed_moe_block.py diff --git a/tests/jax/test_distributed_moe_block.py b/tests/jax/test_distributed_moe_block.py new file mode 100644 index 0000000000..9d9e57140f --- /dev/null +++ b/tests/jax/test_distributed_moe_block.py @@ -0,0 +1,143 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Distributed tests for ``transformer_engine.jax.flax.MoEBlock``.""" + +import sys + +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np +import pytest +from jax.sharding import Mesh, PartitionSpec + +from utils import assert_allclose, is_devices_enough + + +@pytest.fixture(autouse=True, scope="function") +def _inject_moe(request): + """Lazy-load ``MoEBlock`` only for tests marked ``triton``.""" + if not request.node.get_closest_marker("triton"): + yield + return + + from transformer_engine.jax import MeshResource, autocast + from transformer_engine.jax.flax import MoEBlock + + mod = sys.modules[__name__] + mod.MeshResource = MeshResource + mod.autocast = autocast + mod.MoEBlock = MoEBlock + yield + + +DTYPE = jnp.bfloat16 +BATCH_SIZE = 2 +SEQUENCE_LENGTH = 16 +HIDDEN_SIZE = 64 +INTERMEDIATE_SIZE = 128 +NUM_EXPERTS = 8 +NUM_EXPERTS_PER_TOK = 2 + + +def _make_inputs(key: jax.Array) -> jax.Array: + return jax.random.normal( + key, (BATCH_SIZE, SEQUENCE_LENGTH, HIDDEN_SIZE), dtype=DTYPE + ) + + +def _unwrap_partitioned(x): + return x.value if hasattr(x, "value") else x + + +@pytest.mark.triton +class TestDistributedMoEBlock: + @pytest.mark.parametrize("permutation_backend", ["pure_jax", "triton"]) + def test_ep2_fsdp2_matches_single_device(self, permutation_backend): + if not is_devices_enough(4): + 
pytest.skip("MoE distributed test requires 4 devices for EP=2 x FSDP=2.") + + key = jax.random.PRNGKey(11) + init_key, data_key = jax.random.split(key) + inputs = _make_inputs(data_key) + + base_kwargs = dict( + num_experts=NUM_EXPERTS, + num_experts_per_tok=NUM_EXPERTS_PER_TOK, + intermediate_size=INTERMEDIATE_SIZE, + permutation_backend=permutation_backend, + aux_loss_coeff=1e-2, + dtype=DTYPE, + ) + + single_block = MoEBlock(**base_kwargs) + + def loss_fn(block, variables, x): + output, aux_loss = block.apply(variables, x) + loss = jnp.mean(output.astype(jnp.float32) ** 2) + if aux_loss is not None: + loss = loss + aux_loss.astype(jnp.float32) + return loss, (output, aux_loss) + + with autocast(enabled=False, mesh_resource=MeshResource()): + single_variables = single_block.init(init_key, inputs) + (single_loss, (single_output, single_aux)), single_grads = jax.value_and_grad( + loss_fn, argnums=1, has_aux=True + )(single_block, single_variables, inputs) + + devices = np.asarray(jax.devices()[:4]).reshape(2, 2) + mesh = Mesh(devices, ("ep", "fsdp")) + # FSDP-style sharding: weights are sharded on a *non-contracting* + # weight axis (gathered before the GEMM); activations stay sharded on + # the *batch* axis throughout - the same fsdp mesh axis is reused for + # both. The TE primitives' custom_partitioning rules expect activations + # FSDP-sharded on batch, so we declare ("batch", "fsdp") AND pass + # ``input_axes=("batch", None, None)`` to enforce it on the inputs to + # the block. ("embed", "fsdp") shards the weight's hidden dim, which + # is gathered inside grouped_dense's custom_partitioning before GEMM + # (no reshard of activations needed because their layout is unchanged). 
+ logical_axis_rules = ( + ("exp", "ep"), + ("batch", "fsdp"), + ("embed", "fsdp"), + ) + sharded_block = MoEBlock( + expert_parallelism_axis="ep", + mesh=mesh, + input_axes=("batch", None, None), + **base_kwargs, + ) + + with mesh, autocast(enabled=False, mesh_resource=MeshResource(fsdp_resource="fsdp")): + with nn.logical_axis_rules(logical_axis_rules): + sharded_variables = sharded_block.init(init_key, inputs) + (sharded_loss, (sharded_output, sharded_aux)), sharded_grads = ( + jax.value_and_grad(loss_fn, argnums=1, has_aux=True)( + sharded_block, sharded_variables, inputs + ) + ) + + wi_0 = _unwrap_partitioned(sharded_variables["params"]["wi_0"]) + wi_1 = _unwrap_partitioned(sharded_variables["params"]["wi_1"]) + wo = _unwrap_partitioned(sharded_variables["params"]["wo"]) + assert wi_0.sharding.spec == PartitionSpec("ep", "fsdp", None) + assert wi_1.sharding.spec == PartitionSpec("ep", "fsdp", None) + assert wo.sharding.spec == PartitionSpec("ep", None, "fsdp") + + assert_allclose(sharded_output, single_output, dtype=DTYPE, atol=5e-2, rtol=5e-2) + assert_allclose(sharded_loss, single_loss, dtype=jnp.float32, atol=5e-2, rtol=5e-2) + assert_allclose(sharded_aux, single_aux, dtype=jnp.float32, atol=5e-2, rtol=5e-2) + + for name in ("gate_kernel", "wi_0", "wi_1", "wo"): + grad_single = _unwrap_partitioned(single_grads["params"][name]) + grad_sharded = _unwrap_partitioned(sharded_grads["params"][name]) + assert_allclose( + grad_sharded, + grad_single, + dtype=DTYPE, + atol=1e-1, + rtol=1e-1, + err_msg=f"Distributed gradient mismatch for {name}", + ) diff --git a/tests/jax/test_moe_block.py b/tests/jax/test_moe_block.py index 458d674c7d..45cce2a60c 100644 --- a/tests/jax/test_moe_block.py +++ b/tests/jax/test_moe_block.py @@ -79,6 +79,11 @@ def _init_and_apply( return variables, output, aux_loss +def _unwrap_partitioned(x): + """Strip Flax logical-partition wrappers for numeric assertions.""" + return x.value if hasattr(x, "value") else x + + # 
----------------------------------------------------------------------------- # Tests # ----------------------------------------------------------------------------- @@ -132,7 +137,7 @@ def loss_fn(variables, inputs): grads = jax.grad(loss_fn)(variables, inputs) # All trainable kernels should receive a non-trivial gradient. for name in ("gate_kernel", "wi_0", "wi_1", "wo"): - g = grads["params"][name] + g = _unwrap_partitioned(grads["params"][name]) assert jnp.all(jnp.isfinite(g)), f"{name} gradient has NaN/Inf" assert jnp.any(g != 0.0), f"{name} gradient is identically zero" @@ -183,8 +188,8 @@ def loss_fn(block, variables, inputs): assert jnp.allclose(loss_pj, loss_tr, atol=atol_out, rtol=rtol_out) for name in ("gate_kernel", "wi_0", "wi_1", "wo"): - g_pj = grads_pj["params"][name] - g_tr = grads_tr["params"][name] + g_pj = _unwrap_partitioned(grads_pj["params"][name]) + g_tr = _unwrap_partitioned(grads_tr["params"][name]) assert jnp.allclose(g_pj, g_tr, atol=1e-1, rtol=1e-1), ( f"Gradient for {name} differs across backends: max diff" f" {jnp.max(jnp.abs(g_pj - g_tr))}" @@ -238,6 +243,18 @@ def test_group_topk_deepseek(self, permutation_backend): assert output.shape == inputs.shape assert jnp.all(jnp.isfinite(output)) + @pytest.mark.xfail( + reason=( + "TE grouped_dense FFI currently asserts sum(group_sizes) == M " + "(see csrc/extensions/gemm.cpp). With align_size > 0 the dispatch " + "buffer is padded to a static worst-case size, so M can exceed " + "sum(group_sizes). The MoE block deliberately does not fold the " + "gap into a single expert (that would create per-shard load " + "imbalance under EP). Re-enable once the FFI check is relaxed to " + "M >= sum(group_sizes)." 
+ ), + strict=False, + ) def test_align_size_equivalence_pure_jax(self): """For the pure-JAX backend, ``align_size > 0`` must not change the numerical output of the forward pass: padding tokens contribute zero diff --git a/transformer_engine/jax/flax/moe.py b/transformer_engine/jax/flax/moe.py index 6673ac1a71..5f257dc577 100644 --- a/transformer_engine/jax/flax/moe.py +++ b/transformer_engine/jax/flax/moe.py @@ -425,7 +425,7 @@ def _forward_ring_ep( gate_logits: jnp.ndarray, params: dict, ) -> Tuple[jnp.ndarray, jnp.ndarray]: - """Wrap :meth:`_forward_body` in a ring-of-experts ``shard_map``. + """Wrap ``_forward_body`` in a ring-of-experts ``shard_map``. For each EP shard the wrapper: 1. ``all_gather`` s the local inputs / logits / expert_bias along @@ -566,7 +566,7 @@ def _route( # The score-for-aux kernel runs independently (no data dependency # on the main kernel), so XLA can overlap them on the GPU. aux_scores, aux_routing_map = fused_topk_with_score_function( - logits_2d, + logits_2d.astype(jnp.float32), topk=self.num_experts_per_tok, score_function=self.score_function, compute_aux_scores=True, @@ -575,7 +575,7 @@ def _route( aux_routing_map.astype(jnp.int32), axis=0 ) aux_loss = fused_moe_aux_loss( - aux_scores, + aux_scores.astype(jnp.float32), aux_tokens_per_expert, topk=self.num_experts_per_tok, coeff=self.aux_loss_coeff, @@ -619,11 +619,21 @@ def _dispatch_and_expert_ffn( ) # Slice group_sizes to just this shard's experts. When not using # EP, ``num_experts_local == self.num_experts`` so this is a no-op. + # + # NOTE on padded buffers (``align_size > 0``): + # ``unfused_token_dispatch`` pads ``sorted_inputs`` to a static + # worst-case row count so JIT shape inference is happy. The + # returned ``group_sizes`` counts only real tokens plus genuine + # alignment-padding tokens; the remaining rows are zero-input + # placeholders that ``grouped_dense`` does not need to touch.
+ # + # TE's ``grouped_dense`` FFI today asserts strictly + # ``sum(group_sizes) == sorted_inputs.shape[0]``. When that + # assertion is relaxed to ``>=`` (the GEMM only iterates over the + # first ``sum(group_sizes)`` rows anyway), this code works as-is. + # Folding the gap into a single expert would create a per-shard + # load imbalance and is intentionally avoided here. group_sizes = group_sizes[:num_experts_local] - # ``local_real_size = sum(group_sizes)`` is the number of permuted - # rows that actually correspond to tokens routed to this shard's - # experts. Used by the ring-EP caller to zero out garbage rows - # before combine. combine_state = { "backend": "pure_jax", "perm_state": perm_state, From 12b1251a049d7ce8c48b95a017f61e29fe7092d4 Mon Sep 17 00:00:00 2001 From: tdophung Date: Wed, 29 Apr 2026 18:02:18 -0700 Subject: [PATCH 5/6] refactor to a2a from roe Signed-off-by: tdophung --- tests/jax/test_moe_block.py | 18 +- transformer_engine/jax/flax/moe.py | 945 +++++++++++++++----------- transformer_engine/jax/permutation.py | 336 +++++++++ 3 files changed, 908 insertions(+), 391 deletions(-) diff --git a/tests/jax/test_moe_block.py b/tests/jax/test_moe_block.py index 45cce2a60c..39a6bfd592 100644 --- a/tests/jax/test_moe_block.py +++ b/tests/jax/test_moe_block.py @@ -245,13 +245,17 @@ def test_group_topk_deepseek(self, permutation_backend): @pytest.mark.xfail( reason=( - "TE grouped_dense FFI currently asserts sum(group_sizes) == M " - "(see csrc/extensions/gemm.cpp). With align_size > 0 the dispatch " - "buffer is padded to a static worst-case size, so M can exceed " - "sum(group_sizes). The MoE block deliberately does not fold the " - "gap into a single expert (that would create per-shard load " - "imbalance under EP). Re-enable once the FFI check is relaxed to " - "M >= sum(group_sizes)." + "TE grouped_dense FFI asserts sum(group_sizes) == M at " + "transformer_engine/jax/csrc/extensions/gemm.cpp:1029. 
With " + "align_size > 0 both backends produce a buffer where M >= " + "sum(group_sizes) (the slack is structural padding for JIT). " + "The kernel itself iterates over per-expert m_i from " + "group_sizes via nvte_multi_tensor_gemm and never reads past " + "sum(group_sizes), so relaxing that assertion to " + "`m >= sum_group_sizes` is the cleanest fix. The MoE block " + "deliberately does not fold the gap into a single expert " + "(that would create per-shard load imbalance under EP). " + "Re-enable once the FFI check is relaxed." ), strict=False, ) diff --git a/transformer_engine/jax/flax/moe.py b/transformer_engine/jax/flax/moe.py index 5f257dc577..690d804e38 100644 --- a/transformer_engine/jax/flax/moe.py +++ b/transformer_engine/jax/flax/moe.py @@ -6,8 +6,54 @@ This module exposes :class:`MoEBlock`, a self-contained Flax Linen MoE layer that wires together TE's fused router, a selectable token-dispatch backend -(pure-JAX ``unfused_*`` or fused Triton), TE's ``grouped_dense``, and optional -ring-of-experts Expert Parallelism. +(pure-JAX ``unfused_*`` or fused Triton), TE's ``grouped_dense``, and an +optional ragged-all-to-all (A2A / A2Av) expert-parallelism strategy. + +Architecture +------------ + +The MoEBlock is decomposed into orthogonal stages so the EP wrapper can +inject collectives between them: + +* ``_route``: gate logits -> top-k routing decisions (+ aux loss). +* ``_global_permute``: scatter tokens to experts; produces + ``[num_tokens*topk + maybe_padding, hidden]`` and + per-expert ``group_sizes`` of length ``num_experts``. +* ``_expert_ffn``: three ``grouped_dense`` calls + activation. Operates + on whatever ``(rows, group_sizes, n_groups)`` it is + handed -- agnostic to whether ``n_groups`` is the + global expert count (no-EP) or the local expert + count (A2A-EP). +* ``_global_combine``: inverse of ``_global_permute`` -- gather + weighted + sum across top-k experts. 
+ +Two top-level forward variants compose those stages: + +* ``_forward_no_ep``: route -> permute -> ffn -> combine. Each TE + primitive's ``custom_partitioning`` rule handles + DP / FSDP / TP automatically. +* ``_forward_a2a_ep``: wraps the body in :func:`jax.shard_map` and inserts + ``all_gather(group_sizes)`` + forward + ``ragged_all_to_all`` + local permute around the + FFN, plus their inverses afterwards. This is the + only place ``shard_map`` is used; A2A is the + canonical EP strategy because the in-flight NCCL + EP component will require this same data layout. + +Note on ``align_size > 0`` +-------------------------- + +Both permutation backends pad each expert's group to a multiple of +``align_size`` when requested, which is what CUBLASLt's grouped GEMM wants +for FP8 shape selection. The pure-JAX backend additionally appends a +zero-input padding tail to keep the buffer statically sized for JIT, so +``sum(group_sizes) <= sorted_inputs.shape[0]`` strictly. TE's +``grouped_dense`` FFI today asserts ``m == sum(group_sizes)`` at +``transformer_engine/jax/csrc/extensions/gemm.cpp:1029``; relaxing that +check to ``m >= sum(group_sizes)`` (the kernel itself only iterates over +``sum(group_sizes)`` rows via ``nvte_multi_tensor_gemm``) is the cleanest +way to support ``align_size > 0`` end-to-end. Until that lands the +``align_size > 0`` tests stay xfail. """ from typing import Any, Callable, NewType, Optional, Tuple, Union @@ -20,6 +66,10 @@ from ..dense import grouped_dense from ..permutation import ( _routing_map_to_selected_experts, + compute_ragged_all_to_all_params, + compute_reverse_ragged_all_to_all_params, + local_permute_after_a2a, + local_unpermute_before_a2a, token_combine, token_dispatch, unfused_token_combine, @@ -49,15 +99,28 @@ class MoEBlock(TransformerEngineBase): """Mixture-of-Experts Flax Linen block. 
Encapsulates the full MoE forward pass: gate projection, fused top-k - routing, optional auxiliary load-balancing loss, token dispatch, per-expert - two-layer FFN via grouped GEMMs, activation, token combine, and optional - ring-of-experts expert parallelism. - - The permutation step is pluggable via ``permutation_backend``: - ``"pure_jax"`` (default) uses the pure-JAX argsort-based - ``unfused_token_dispatch`` / ``unfused_token_combine`` in - :mod:`transformer_engine.jax.permutation`; ``"triton"`` uses TE's fused - ``token_dispatch`` / ``token_combine`` kernels. + routing, optional auxiliary load-balancing loss, token dispatch, + per-expert two-layer FFN via grouped GEMMs, activation, token combine, + and optional ragged-all-to-all expert parallelism. + + Two permutation backends are pluggable via ``permutation_backend``: + + * ``"pure_jax"`` (default) -- argsort-based + :func:`~transformer_engine.jax.permutation.unfused_token_dispatch` / + :func:`~transformer_engine.jax.permutation.unfused_token_combine`. + Faster than Triton in profiling for DeepSeek-style configs. + * ``"triton"`` -- TE's fused + :func:`~transformer_engine.jax.permutation.token_dispatch` / + :func:`~transformer_engine.jax.permutation.token_combine` Triton + kernels. + + Expert parallelism (``expert_parallelism_axis is not None``) uses the + **ragged-all-to-all** EP strategy (a.k.a. A2Av): each shard routes its + own tokens globally over all experts, then a forward + ``ragged_all_to_all`` exchanges per-expert chunks so each shard ends up + holding only the tokens for its local experts; after the FFN a reverse + ``ragged_all_to_all`` returns each shard's outputs to it. This matches + the layout the in-flight NCCL EP component expects. Parameters ---------- @@ -70,70 +133,72 @@ class MoEBlock(TransformerEngineBase): activation_type : str FFN activation applied to the gate projection. Paired with the up - projection in the SwiGLU-style ``act(wi_0) * wi_1`` product. 
Resolved - via :func:`flax.linen.` (``"silu"``, ``"gelu"``, ``"relu"``, - ``"swish"``, ...) plus ``"linear"`` for identity. + projection in the SwiGLU-style ``act(wi_0) * wi_1`` product. + Resolved via :func:`flax.linen.` (``"silu"``, ``"gelu"``, + ``"relu"``, ``"swish"``, ...) plus ``"linear"`` for identity. score_function : str or ScoreFunction - ``"softmax"`` (default) or ``"sigmoid"`` for :func:`fused_topk_with_score_function`. + ``"softmax"`` (default) or ``"sigmoid"`` for + :func:`fused_topk_with_score_function`. use_pre_softmax : bool Apply softmax before top-k when ``score_function="softmax"``. num_groups : int - Number of routing groups for grouped top-k (DeepSeek). ``<=0`` disables. + Number of routing groups for grouped top-k (DeepSeek). ``<=0`` + disables. group_topk : int Top-k at the group level. ``<=0`` disables. scaling_factor : float Scaling factor applied to output probs. use_expert_bias : bool - If ``True``, registers a learnable ``expert_bias`` parameter of shape - ``[num_experts]`` and passes it to the fused router. The router - primitive validates that this is paired with ``score_function="sigmoid"``. + If ``True``, registers a learnable ``expert_bias`` parameter of + shape ``[num_experts]`` and passes it to the fused router. The + router primitive validates that this is paired with + ``score_function="sigmoid"``. aux_loss_coeff : float - If ``> 0``, compute and return the MoE auxiliary load-balancing loss - scalar via :func:`fused_moe_aux_loss`. ``0`` disables. + If ``> 0``, compute and return the MoE auxiliary load-balancing + loss scalar via :func:`fused_moe_aux_loss`. ``0`` disables. gate_kernel_axes : tuple[str, ...] Logical partitioning axes for the gate kernel of shape ``[hidden, num_experts]``. wi_kernel_axes : tuple[str, ...] Logical partitioning axes for the ``wi_0`` and ``wi_1`` kernels of - shape ``[num_experts, hidden, intermediate]``. Default: + shape ``[num_experts, hidden, intermediate]``. Default ``("exp", "embed", "mlp")``. 
wo_kernel_axes : tuple[str, ...] Logical partitioning axes for the ``wo`` kernel of shape - ``[num_experts, intermediate, hidden]``. Default: + ``[num_experts, intermediate, hidden]``. Default ``("exp", "mlp", "embed")``. input_axes : tuple[str, ...] Logical axes used to constrain the input activation sharding at the block boundary. ``()`` (default) means no constraint. expert_parallelism_axis : Optional[str] - Mesh axis along which experts are split. When set, the forward pass - is wrapped in :func:`jax.experimental.shard_map.shard_map` that - implements the ring-of-experts EP strategy: ``all_gather`` on inputs - and gate logits, local routing + dispatch + FFN + combine, then - ``psum_scatter`` on the output. When ``None`` (default), no - ``shard_map`` wrapper is used; each primitive's ``custom_partitioning`` - rule handles DP/FSDP/TP automatically. + Mesh axis along which experts are split. When set, the forward + pass is wrapped in :func:`jax.shard_map` that implements the + ragged-all-to-all EP strategy. When ``None`` (default), no + ``shard_map`` wrapper is used; each TE primitive's + ``custom_partitioning`` rule handles DP / FSDP / TP automatically. tensor_parallelism_axis : Optional[str] Mesh axis for tensor parallelism on the FFN intermediate dim. When set, the output of the ``wo`` grouped GEMM is ``psum_scatter`` ed - along this axis (inside the ``shard_map`` when EP is enabled, else at - the end of the forward pass). + along this axis. permutation_backend : str ``"pure_jax"`` (default) or ``"triton"``. align_size : int Alignment for per-expert group sizes after padding. ``0`` disables - padding (faster for the unquantized path). ``>0`` is required for - quantized TE grouped GEMM whose recipe-specific alignment must divide - ``align_size``. + padding (the only supported configuration end-to-end today). 
``>0`` + is required for quantized TE grouped GEMM whose recipe-specific + alignment must divide ``align_size``; see the module docstring for + the FFI assertion that currently blocks ``>0`` for both backends. dtype : jnp.dtype Compute and parameter dtype. kernel_init : Initializer Initializer for all kernels (gate + per-expert FFN). Defaults to - ``variance_scaling(1.0, 'fan_in', 'truncated_normal')`` (Flax convention). + ``variance_scaling(1.0, 'fan_in', 'truncated_normal')`` (Flax + convention). use_bias : bool If ``True``, registers per-expert FFN biases ``wi_0_bias``, ``wi_1_bias``, ``wo_bias``. @@ -198,7 +263,7 @@ def __post_init__(self): # Parameter registration # ------------------------------------------------------------------ - def _make_params(self, hidden_size: int): + def _make_params(self, hidden_size: int) -> dict: """Register module parameters and return them as a dict.""" gate_kernel = self.param( "gate_kernel", @@ -224,7 +289,7 @@ def _make_params(self, hidden_size: int): (self.num_experts, self.intermediate_size, hidden_size), self.dtype, ) - params = { + params: dict = { "gate_kernel": gate_kernel, "wi_0": wi_0, "wi_1": wi_1, @@ -276,8 +341,8 @@ def __call__(self, inputs: Array) -> Tuple[Array, Optional[Array]]: output : jnp.ndarray Output tensor of shape ``[batch, sequence, hidden]``. aux_loss : Optional[jnp.ndarray] - Scalar auxiliary load-balancing loss when ``aux_loss_coeff > 0``, - else ``None``. + Scalar auxiliary load-balancing loss when + ``aux_loss_coeff > 0``, else ``None``. """ assert inputs.ndim == 3, ( f"MoEBlock expects [batch, sequence, hidden] input, got shape {inputs.shape}" @@ -287,27 +352,15 @@ def __call__(self, inputs: Array) -> Tuple[Array, Optional[Array]]: _, _, hidden_size = inputs.shape params = self._make_params(hidden_size) - # Gate runs OUTSIDE the EP shard_map below, so each EP shard projects - # its own local slice of tokens and we later all-gather only the - # smaller logits tensor instead of the full inputs. 
+ # The gate runs OUTSIDE any EP shard_map: under EP each shard + # projects only its local slice of tokens, producing local gate + # logits with the same per-shard layout as ``inputs``. gate_logits = self._gate(inputs, params["gate_kernel"]) if self.expert_parallelism_axis is None: - # No EP: each primitive's own ``custom_partitioning`` rule handles - # DP / FSDP / TP across the mesh - no shard_map needed. - output, aux_loss = self._forward_body( - inputs, - gate_logits, - params, - num_experts_local=self.num_experts, - roll_to_expert_id=None, - ) + output, aux_loss = self._forward_no_ep(inputs, gate_logits, params) else: - # Ring-EP: ``_forward_body`` is wrapped in a shard_map that - # orchestrates the cross-primitive collectives (all_gather inputs - # / logits before, psum_scatter output after) which per-primitive - # ``custom_partitioning`` cannot express on its own. - output, aux_loss = self._forward_ring_ep(inputs, gate_logits, params) + output, aux_loss = self._forward_a2a_ep(inputs, gate_logits, params) if self.aux_loss_coeff <= 0.0: aux_loss = None @@ -320,235 +373,31 @@ def __call__(self, inputs: Array) -> Tuple[Array, Optional[Array]]: def _gate(self, inputs: jnp.ndarray, gate_kernel: jnp.ndarray) -> jnp.ndarray: """Linear gate projection ``inputs @ gate_kernel``. - Kept as a plain matmul (not ``DenseGeneral``) so it integrates cleanly - with the EP shard_map: the gate matmul runs in the outer (pre-shard_map) - scope and its output is all-gathered along the EP axis inside. + Kept as a plain ``einsum`` (not ``DenseGeneral``) so it composes + cleanly with the EP shard_map: the gate runs in the outer + (pre-shard_map) scope and its output passes through the + ``shard_map`` boundary unchanged. """ - # Cast kernel to input dtype outside FP8 scope (gate is typically BF16/FP32). 
kernel = gate_kernel.astype(inputs.dtype) return jnp.einsum("bsh,he->bse", inputs, kernel) - # ------------------------------------------------------------------ - # Forward body (shared between no-EP and ring-EP paths) - # ------------------------------------------------------------------ - - def _forward_body( - self, - inputs: jnp.ndarray, - gate_logits: jnp.ndarray, - params: dict, - num_experts_local: int, - roll_to_expert_id: Optional[int], - ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: - """Routing + dispatch + per-expert FFN + combine. - - Used both bare (no EP) and inside the ring-EP shard_map. In the - ring-EP case ``inputs`` and ``gate_logits`` are the post-all_gather - global tensors, ``num_experts_local == num_experts // num_ep``, and - ``roll_to_expert_id`` is the offset that brings this shard's experts - into slots ``[0, num_experts_local)``. - """ - batch_size, sequence_length, hidden_size = inputs.shape - inputs_2d = inputs.reshape(-1, hidden_size) - logits_2d = gate_logits.reshape(-1, self.num_experts) - - sparse_probs, routing_map, aux_loss = self._route( - logits_2d, params.get("expert_bias") - ) - - if roll_to_expert_id is not None: - # Rotate expert columns so this shard's experts come first. - routing_map = jnp.roll(routing_map, -roll_to_expert_id, axis=-1) - sparse_probs = jnp.roll(sparse_probs, -roll_to_expert_id, axis=-1) - if self.permutation_backend == "triton": - # Triton path: zero out remote-expert columns so the fused - # ``token_dispatch`` never writes tokens routed off-shard. - # The pure-JAX path zeroes garbage *output* rows below - # instead, since masking the routing_map directly would - # break the argsort-based permutation. 
- local_mask = ( - jnp.arange(self.num_experts) < num_experts_local - ) - routing_map = routing_map * local_mask - sparse_probs = sparse_probs * local_mask.astype(sparse_probs.dtype) - - expert_outputs, combine_state = self._dispatch_and_expert_ffn( - inputs_2d, - sparse_probs, - routing_map, - params, - num_experts_local=num_experts_local, - # The roll is already baked into ``routing_map``/``sparse_probs`` - # above, so the unfused dispatch must not roll again. - roll_to_expert_id=0 if roll_to_expert_id is not None else None, - ) - - if ( - roll_to_expert_id is not None - and self.permutation_backend == "pure_jax" - ): - # Zero the rows of ``expert_outputs`` past the real local-expert - # token count: ``grouped_dense`` leaves them as garbage because - # ``group_sizes`` was truncated to the local slice. Without this - # the unsort + weighted-sum in combine would mix garbage into - # every token's output (mirrors Maxtext's moe.py). - real_mask = ( - jnp.arange(expert_outputs.shape[0]) - < combine_state["local_real_size"] - ) - expert_outputs = jnp.where(real_mask[:, None], expert_outputs, 0) - - output = self._combine( - expert_outputs, - combine_state, - batch_size=batch_size, - sequence_length=sequence_length, - ) - - if self.tensor_parallelism_axis is not None: - output = jax.lax.psum_scatter( - output, - self.tensor_parallelism_axis, - scatter_dimension=2, - tiled=True, - ) - - return output, aux_loss - - # ------------------------------------------------------------------ - # Ring-of-Experts EP wrapper - # ------------------------------------------------------------------ - - def _forward_ring_ep( - self, - inputs: jnp.ndarray, - gate_logits: jnp.ndarray, - params: dict, - ) -> Tuple[jnp.ndarray, jnp.ndarray]: - """Wrap ``_forward_body`` in a ring-of-experts ``shard_map``. - - For each EP shard the wrapper: - 1. ``all_gather`` s the local inputs / logits / expert_bias along - the EP axis so the routing sees every token globally. - 2. 
Calls ``_forward_body`` with ``roll_to_expert_id = - num_experts_per_shard * shard_id`` and the EP-local weight slice. - 3. ``psum_scatter`` s the resulting ``[B*num_ep, S, H]`` output back - to the EP-sharded ``[B, S, H]`` layout. - """ - from jax.experimental.shard_map import shard_map - - ep_axis = self.expert_parallelism_axis - if self.mesh is None: - raise ValueError( - "MoEBlock.expert_parallelism_axis is set; `mesh` must also be" - " provided so the ring-of-experts shard_map can be built." - ) - mesh = self.mesh - num_ep = mesh.shape[ep_axis] - assert self.num_experts % num_ep == 0, ( - f"num_experts={self.num_experts} must be divisible by EP size={num_ep}" - ) - num_experts_per_shard = self.num_experts // num_ep - - # Pack everything that crosses the shard_map boundary into a dict - # pytree. shard_map fully supports pytrees: ``in_specs`` must - # structurally match ``captured``, and we build them in lockstep so - # adding/removing an optional bias is a single ``dict[name] = ...``. 
- captured: dict = { - "inputs": inputs, - "gate_logits": gate_logits, - "wi_0": params["wi_0"], - "wi_1": params["wi_1"], - "wo": params["wo"], - } - in_specs: dict = { - "inputs": P(ep_axis, None, None), - "gate_logits": P(ep_axis, None, None), - "wi_0": P(ep_axis, None, None), - "wi_1": P(ep_axis, None, None), - "wo": P(ep_axis, None, None), - } - if "expert_bias" in params: - captured["expert_bias"] = params["expert_bias"] - in_specs["expert_bias"] = P(ep_axis) - if "wi_0_bias" in params: - for name in ("wi_0_bias", "wi_1_bias", "wo_bias"): - captured[name] = params[name] - in_specs[name] = P(ep_axis, None) - - def _ring_fn(local: dict) -> Tuple[jnp.ndarray, jnp.ndarray]: - shard_id = jax.lax.axis_index(ep_axis) - - gathered_inputs = jax.lax.all_gather( - local["inputs"], axis_name=ep_axis, tiled=True - ) - gathered_logits = jax.lax.all_gather( - local["gate_logits"], axis_name=ep_axis, tiled=True - ) - - local_params: dict = { - "wi_0": local["wi_0"], - "wi_1": local["wi_1"], - "wo": local["wo"], - } - if "expert_bias" in local: - # The router operates over the full expert axis, so the - # EP-sharded bias must be all-gathered. - local_params["expert_bias"] = jax.lax.all_gather( - local["expert_bias"], axis_name=ep_axis, tiled=True - ) - if "wi_0_bias" in local: - local_params["wi_0_bias"] = local["wi_0_bias"] - local_params["wi_1_bias"] = local["wi_1_bias"] - local_params["wo_bias"] = local["wo_bias"] - - output, aux_loss = self._forward_body( - gathered_inputs, - gathered_logits, - local_params, - num_experts_local=num_experts_per_shard, - roll_to_expert_id=num_experts_per_shard * shard_id, - ) - - # ``output`` is [B*num_ep, S, H] (global batch after all_gather); - # psum_scatter along EP returns the local [B, S, H] slice. 
- output = jax.lax.psum_scatter( - output, ep_axis, scatter_dimension=0, tiled=True - ) - - # ``out_specs`` must match the returned pytree structurally, so - # always emit a real scalar for aux_loss; the outer ``__call__`` - # re-strips it to None when ``aux_loss_coeff <= 0``. - if aux_loss is None: - aux_loss = jnp.zeros((), dtype=self.dtype) - return output, aux_loss - - # ``check_rep=False`` disables shard_map's invariant that any output - # declared as ``P()`` is replicated across ``ep_axis``. We use - # ``axis_index(ep_axis)`` inside ``_ring_fn`` to compute a per-shard - # roll, which makes the body genuinely non-replicated and would - # otherwise (correctly) fail the check. The ``psum_scatter`` of the - # output already produces the right cross-shard semantics; this is - # the standard JAX escape hatch when collectives + per-shard logic - # coexist. - return shard_map( - _ring_fn, - mesh=mesh, - in_specs=in_specs, - out_specs=(P(ep_axis, None, None), P()), - check_rep=False, - )(captured) - # ------------------------------------------------------------------ # Route # ------------------------------------------------------------------ - - def _route( + # + # The router is split into two pieces so the EP path can compute + # aux_loss over global (cross-shard) statistics without re-running + # the main top-k path. ``_route_topk`` returns the per-token routing + # decisions (used by ``_global_permute``) and ``_compute_aux_loss`` + # returns the scalar load-balancing loss given the (possibly + # gathered) logits. 
+ + def _route_topk( self, logits_2d: jnp.ndarray, expert_bias: Optional[jnp.ndarray], - ) -> Tuple[jnp.ndarray, jnp.ndarray, Optional[jnp.ndarray]]: - """Run the fused router and optional aux-loss.""" + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Run the fused router top-k selection.""" sparse_probs, routing_map = fused_topk_with_score_function( logits_2d, topk=self.num_experts_per_tok, @@ -560,47 +409,73 @@ def _route( expert_bias=expert_bias, ) sparse_probs = sparse_probs.astype(self.dtype) + return sparse_probs, routing_map - aux_loss = None - if self.aux_loss_coeff > 0.0: - # The score-for-aux kernel runs independently (no data dependency - # on the main kernel), so XLA can overlap them on the GPU. - aux_scores, aux_routing_map = fused_topk_with_score_function( - logits_2d.astype(jnp.float32), - topk=self.num_experts_per_tok, - score_function=self.score_function, - compute_aux_scores=True, - ) - aux_tokens_per_expert = jnp.sum( - aux_routing_map.astype(jnp.int32), axis=0 - ) - aux_loss = fused_moe_aux_loss( - aux_scores.astype(jnp.float32), - aux_tokens_per_expert, - topk=self.num_experts_per_tok, - coeff=self.aux_loss_coeff, - ) - - return sparse_probs, routing_map, aux_loss + def _compute_aux_loss( + self, + logits_2d: jnp.ndarray, + ) -> Optional[jnp.ndarray]: + """Compute the MoE auxiliary load-balancing loss. + + The score-for-aux kernel has no data dependency on the main + routing kernel, so XLA can overlap them on the GPU. + + ``logits_2d`` should be the *full* logits tensor over the global + token batch -- under EP the caller is responsible for + :func:`jax.lax.all_gather` ing the logits before calling this so + the aux_loss formula + ``loss = (E * coeff / (k * T^2)) * sum_i(sum_t(probs[t,i]) * tokens[i])`` + sees the global ``T`` and the global ``tokens_per_expert``. 
+ """ + if self.aux_loss_coeff <= 0.0: + return None + aux_scores, aux_routing_map = fused_topk_with_score_function( + logits_2d.astype(jnp.float32), + topk=self.num_experts_per_tok, + score_function=self.score_function, + compute_aux_scores=True, + ) + aux_tokens_per_expert = jnp.sum( + aux_routing_map.astype(jnp.int32), axis=0 + ) + return fused_moe_aux_loss( + aux_scores.astype(jnp.float32), + aux_tokens_per_expert, + topk=self.num_experts_per_tok, + coeff=self.aux_loss_coeff, + ) # ------------------------------------------------------------------ - # Dispatch + expert FFN + # Global permute (route -> token dispatch) # ------------------------------------------------------------------ - def _dispatch_and_expert_ffn( + def _global_permute( self, inputs_2d: jnp.ndarray, sparse_probs: jnp.ndarray, routing_map: jnp.ndarray, - params: dict, - num_experts_local: int, - roll_to_expert_id: Optional[int], - ) -> Tuple[jnp.ndarray, dict]: - """Dispatch tokens, run the three grouped GEMMs + activation, return expert outputs. - - Returns a tuple ``(expert_outputs, combine_state)`` where - ``combine_state`` carries the per-backend state needed to rebuild the - original token ordering in :meth:`_combine`. + ) -> dict: + """Dispatch tokens to the global expert axis. + + Returns a permutation-result dict suitable both for the no-EP + forward (where the same buffer feeds ``_expert_ffn`` directly) and + for the A2A-EP path (where the buffer is sliced + sent over the EP + axis before the FFN). The dict carries the per-backend opaque + state needed to invert the dispatch in :meth:`_global_combine`. + + The output dict layout is:: + + { + "backend": "pure_jax" | "triton", + "sorted_inputs": [buffer_size, hidden], + "group_sizes": [num_experts], # per-expert, + # length == E always. 
+ "perm_state": UnfusedPermState | None, # pure_jax + "row_id_map": jnp.ndarray | None, # triton + "pad_offsets": jnp.ndarray | None, # triton + "routing_weights": jnp.ndarray | None, # pure_jax + "merging_probs": jnp.ndarray | None, # triton + } """ num_tokens = inputs_2d.shape[0] topk = self.num_experts_per_tok @@ -615,79 +490,90 @@ def _dispatch_and_expert_ffn( num_experts=self.num_experts, num_experts_per_tok=topk, align_size=self.align_size, - roll_to_expert_id=roll_to_expert_id, ) - # Slice group_sizes to just this shard's experts. When not using - # EP, ``num_experts_local == self.num_experts`` so this is a no-op. - # - # NOTE on padded buffers (``align_size > 0``): - # ``unfused_token_dispatch`` pads ``sorted_inputs`` to a static - # worst-case row count so JIT shape inference is happy. The - # returned ``group_sizes`` deliberately tracks only real + real - # alignment-padding tokens; the remaining rows are zero-input - # placeholders that ``grouped_dense`` does not need to touch. - # - # TE's ``grouped_dense`` FFI today asserts strictly - # ``sum(group_sizes) == sorted_inputs.shape[0]``. When that - # assertion is relaxed to ``>=`` (the GEMM only iterates over the - # first ``sum(group_sizes)`` rows anyway), this code works as-is. - # Folding the gap into a single expert would create a per-shard - # load imbalance and is intentionally avoided here. 
- group_sizes = group_sizes[:num_experts_local] - combine_state = { + return { "backend": "pure_jax", + "sorted_inputs": sorted_inputs, + "group_sizes": group_sizes, "perm_state": perm_state, "routing_weights": routing_weights, - "local_real_size": jnp.sum(group_sizes), - } - else: # "triton" - num_out_tokens = num_tokens * topk - align_size_arg = self.align_size if self.align_size > 0 else None - ( - sorted_inputs, - _permuted_probs, - row_id_map, - pad_offsets, - group_sizes, - ) = token_dispatch( - inputs_2d, - routing_map, - num_out_tokens=num_out_tokens, - probs=sparse_probs, - align_size=align_size_arg, - ) - group_sizes = group_sizes[:num_experts_local] - combine_state = { - "backend": "triton", - "row_id_map": row_id_map, - "pad_offsets": pad_offsets, - "merging_probs": sparse_probs, - "group_sizes": group_sizes, } - # ------------------------------------------------------------------ - # Expert FFN: grouped GEMMs w0, w1 + activation + w_o. - # ------------------------------------------------------------------ + # triton + num_out_tokens = num_tokens * topk + align_size_arg = self.align_size if self.align_size > 0 else None + ( + sorted_inputs, + _permuted_probs, + row_id_map, + pad_offsets, + group_sizes, + ) = token_dispatch( + inputs_2d, + routing_map, + num_out_tokens=num_out_tokens, + probs=sparse_probs, + align_size=align_size_arg, + ) + return { + "backend": "triton", + "sorted_inputs": sorted_inputs, + "group_sizes": group_sizes, + "row_id_map": row_id_map, + "pad_offsets": pad_offsets, + "merging_probs": sparse_probs, + } + + # ------------------------------------------------------------------ + # Expert FFN (three grouped_dense calls + activation) + # ------------------------------------------------------------------ + + def _expert_ffn( + self, + sorted_inputs: jnp.ndarray, + group_sizes: jnp.ndarray, + params: dict, + n_groups: int, + ) -> jnp.ndarray: + """Run the per-expert SwiGLU-style FFN over a permuted buffer. 
+ + Parameters + ---------- + sorted_inputs : jnp.ndarray + Permuted tokens of shape ``[buffer_size, hidden]`` (rows + grouped by expert). + group_sizes : jnp.ndarray + Per-group token counts of shape ``[n_groups]``. + ``sum(group_sizes)`` must equal ``buffer_size`` (TE + ``grouped_dense`` FFI assertion at + ``transformer_engine/jax/csrc/extensions/gemm.cpp:1029``). + params : dict + Block parameters from :meth:`_make_params`. Reads ``wi_0``, + ``wi_1``, ``wo``, and the optional bias entries. + n_groups : int + Number of expert groups. Equals ``self.num_experts`` for the + no-EP path and ``num_experts // num_ep`` for the A2A-EP path. + Used to size the per-call quantizer set so the FP8 metadata + tensors match ``group_sizes``. + + Returns + ------- + expert_outputs : jnp.ndarray + ``[buffer_size, hidden]``. + """ wi_0 = params["wi_0"] wi_1 = params["wi_1"] wo = params["wo"] # Each grouped_dense call gets its own quantizer_set with - # ``n_groups=num_experts_local``; this matches the shape of - # ``group_sizes`` passed in and keeps the quantizer FP8 meta correctly - # sized per shard. - q_set_w0 = self.generate_quantizer_set( - postfix="_w0", n_groups=num_experts_local - ) - q_set_w1 = self.generate_quantizer_set( - postfix="_w1", n_groups=num_experts_local - ) - q_set_wo = self.generate_quantizer_set( - postfix="_wo", n_groups=num_experts_local - ) - - # Cast kernels to the sort dtype when no FP8 quantization is active - # (mirrors DenseGeneral). + # n_groups matching ``group_sizes``; this keeps the FP8 meta + # tensors correctly sized in both no-EP and A2A-EP cases. + q_set_w0 = self.generate_quantizer_set(postfix="_w0", n_groups=n_groups) + q_set_w1 = self.generate_quantizer_set(postfix="_w1", n_groups=n_groups) + q_set_wo = self.generate_quantizer_set(postfix="_wo", n_groups=n_groups) + + # Cast kernels to the activation dtype when no FP8 quantization + # is active (mirrors DenseGeneral). 
if q_set_w0 == noop_quantizer_set: wi_0 = wi_0.astype(sorted_inputs.dtype) if q_set_w1 == noop_quantizer_set: @@ -695,9 +581,9 @@ def _dispatch_and_expert_ffn( if q_set_wo == noop_quantizer_set: wo = wo.astype(sorted_inputs.dtype) - # ``grouped_dense`` accepts per-expert bias of shape (G, N); it adds - # ``bias[i]`` to the ``group_sizes[i]`` rows belonging to expert ``i`` - # in the permuted layout. + # ``grouped_dense`` accepts per-expert bias of shape (G, N); it + # adds ``bias[i]`` to the ``group_sizes[i]`` rows belonging to + # expert ``i`` in the permuted layout. wi_0_bias = params.get("wi_0_bias") if self.use_bias else None wi_1_bias = params.get("wi_1_bias") if self.use_bias else None wo_bias = params.get("wo_bias") if self.use_bias else None @@ -730,25 +616,30 @@ def _dispatch_and_expert_ffn( bias=wo_bias, quantizer_set=q_set_wo, ) - - return expert_outputs, combine_state + return expert_outputs # ------------------------------------------------------------------ - # Combine + # Global combine (token combine -> back to [B, S, H]) # ------------------------------------------------------------------ - def _combine( + def _global_combine( self, expert_outputs: jnp.ndarray, - combine_state: dict, + perm_result: dict, batch_size: int, sequence_length: int, ) -> jnp.ndarray: - if combine_state["backend"] == "pure_jax": + """Inverse of :meth:`_global_permute`. + + Gathers per-expert outputs back into ``[batch, sequence, hidden]`` + and applies the per-token weighted sum across the top-k experts. 
+ """ + backend = perm_result["backend"] + if backend == "pure_jax": return unfused_token_combine( expert_outputs, - combine_state["perm_state"], - combine_state["routing_weights"], + perm_result["perm_state"], + perm_result["routing_weights"], num_experts_per_tok=self.num_experts_per_tok, batch_size=batch_size, sequence_length=sequence_length, @@ -756,11 +647,297 @@ def _combine( # triton out_2d = token_combine( expert_outputs, - combine_state["row_id_map"], - merging_probs=combine_state["merging_probs"], - pad_offsets=combine_state["pad_offsets"], + perm_result["row_id_map"], + merging_probs=perm_result["merging_probs"], + pad_offsets=perm_result["pad_offsets"], ) hidden_size = out_2d.shape[-1] return out_2d.reshape(batch_size, sequence_length, hidden_size).astype( self.dtype ) + + # ------------------------------------------------------------------ + # No-EP forward + # ------------------------------------------------------------------ + + def _forward_no_ep( + self, + inputs: jnp.ndarray, + gate_logits: jnp.ndarray, + params: dict, + ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: + """Single-shard or DP/FSDP/TP forward (no shard_map wrapper). + + DP / FSDP / TP all flow through each TE primitive's + ``custom_partitioning`` rule -- there is no cross-primitive + collective that the rules cannot express on their own, so a + ``shard_map`` is unnecessary here. 
+ """ + batch_size, sequence_length, hidden_size = inputs.shape + inputs_2d = inputs.reshape(-1, hidden_size) + logits_2d = gate_logits.reshape(-1, self.num_experts) + + sparse_probs, routing_map = self._route_topk( + logits_2d, params.get("expert_bias") + ) + aux_loss = self._compute_aux_loss(logits_2d) + perm = self._global_permute(inputs_2d, sparse_probs, routing_map) + expert_outputs = self._expert_ffn( + perm["sorted_inputs"], + perm["group_sizes"], + params, + n_groups=self.num_experts, + ) + output = self._global_combine( + expert_outputs, perm, batch_size, sequence_length + ) + + if self.tensor_parallelism_axis is not None: + output = jax.lax.psum_scatter( + output, + self.tensor_parallelism_axis, + scatter_dimension=2, + tiled=True, + ) + return output, aux_loss + + # ------------------------------------------------------------------ + # A2A (ragged-all-to-all) EP forward + # ------------------------------------------------------------------ + + def _forward_a2a_ep( + self, + inputs: jnp.ndarray, + gate_logits: jnp.ndarray, + params: dict, + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Wrap the body in a ``shard_map`` that runs a forward + ``ragged_all_to_all`` (A2A / A2Av) around the FFN. + + For each EP shard the wrapper: + + 1. Routes the shard's local tokens **globally** over all + ``num_experts`` experts (no roll, no local-mask -- every shard + sees the full expert axis). + 2. ``all_gather`` s its per-expert ``group_sizes`` so all shards + know the complete ``[num_ep, num_experts]`` token-count matrix. + 3. Forward ``ragged_all_to_all`` over the EP axis: each shard + sends per-expert chunks to the shard that owns those experts, + and receives chunks for its own ``num_experts // num_ep`` + local experts from every other shard. + 4. Reorders the received buffer from ``(source_shard, expert)`` + to ``(expert, source_shard)`` ordering so each local expert's + tokens are contiguous. + 5. 
Runs the three ``grouped_dense`` calls + activation over the + ``E_local``-group buffer. + 6. Reverses the local reorder. + 7. Reverse ``ragged_all_to_all`` over EP returns each shard's + token outputs to it. + 8. Inverts the global permute and applies the top-k weighted sum. + """ + from jax.experimental.shard_map import shard_map + + ep_axis = self.expert_parallelism_axis + if self.mesh is None: + raise ValueError( + "MoEBlock.expert_parallelism_axis is set; `mesh` must also" + " be provided so the EP shard_map can be built." + ) + mesh = self.mesh + num_ep = mesh.shape[ep_axis] + assert self.num_experts % num_ep == 0, ( + f"num_experts={self.num_experts} must be divisible by EP" + f" size={num_ep}" + ) + num_experts_local = self.num_experts // num_ep + + # Pre-compute the worst-case A2A receive buffer size (compile-time + # constant). Each shard contributes ``b_l*S*topk = B*S*topk/num_ep`` + # token-expert pairs across all experts; the worst case for one + # shard is "every global pair lands on this shard's local + # experts" -- ``num_ep * (B*S*topk/num_ep) = B*S*topk`` rows. JIT + # needs this static, so we use the global ``batch_size`` from the + # outer scope (sharded layouts don't change it). + global_batch_size, sequence_length, _hidden = inputs.shape + topk = self.num_experts_per_tok + recv_buffer_rows = global_batch_size * sequence_length * topk + + # Pack everything that crosses the shard_map boundary into a dict + # pytree. shard_map fully supports pytrees: ``in_specs`` must + # structurally match ``captured`` and we build them in lockstep + # so adding/removing an optional bias is one ``dict[name] = ...``. 
+ captured: dict = { + "inputs": inputs, + "gate_logits": gate_logits, + "wi_0": params["wi_0"], + "wi_1": params["wi_1"], + "wo": params["wo"], + } + in_specs: dict = { + "inputs": P(ep_axis, None, None), + "gate_logits": P(ep_axis, None, None), + "wi_0": P(ep_axis, None, None), + "wi_1": P(ep_axis, None, None), + "wo": P(ep_axis, None, None), + } + if "expert_bias" in params: + captured["expert_bias"] = params["expert_bias"] + in_specs["expert_bias"] = P(ep_axis) + if "wi_0_bias" in params: + for name in ("wi_0_bias", "wi_1_bias", "wo_bias"): + captured[name] = params[name] + in_specs[name] = P(ep_axis, None) + + def _a2a_fn(local: dict) -> Tuple[jnp.ndarray, jnp.ndarray]: + shard_id = jax.lax.axis_index(ep_axis) + + # -- Stage 1: per-shard route + global permute over all E -- + # Inside the shard_map body each input has its EP axis already + # consumed, so ``local_inputs.shape == [B/num_ep, S, H]``. + local_inputs = local["inputs"] + local_logits = local["gate_logits"] + local_b, local_s, local_h = local_inputs.shape + inputs_2d = local_inputs.reshape(-1, local_h) + logits_2d = local_logits.reshape(-1, self.num_experts) + + # The router operates over the full expert axis, so the + # EP-sharded ``expert_bias`` (in_spec ``P(ep_axis)``) must be + # all-gathered before being passed in. + if "expert_bias" in local: + full_expert_bias = jax.lax.all_gather( + local["expert_bias"], axis_name=ep_axis, tiled=True + ) + else: + full_expert_bias = None + sparse_probs, routing_map = self._route_topk( + logits_2d, full_expert_bias + ) + + # aux_loss must see the global token batch and the global + # tokens_per_expert: its formula ``E*coeff/(k*T^2) * sum_i( + # sum_t(probs[t,i]) * tokens[i])`` is not shard-decomposable + # (the sum_t * tokens product is data-dependent across + # shards). Cheapest fix: gather logits along the EP axis and + # run the aux-loss kernel on the global tensor. 
The aux + # branch has no data dependency on the main routing path so + # XLA can overlap the two on the GPU. + if self.aux_loss_coeff > 0.0: + global_logits_2d = jax.lax.all_gather( + logits_2d, axis_name=ep_axis, axis=0, tiled=True + ) + aux_loss = self._compute_aux_loss(global_logits_2d) + else: + aux_loss = None + + perm = self._global_permute(inputs_2d, sparse_probs, routing_map) + global_group_sizes = perm["group_sizes"] # [E] + + # -- Stage 2: gather per-expert counts across the EP axis -- + all_shards_tokens_per_expert = jax.lax.all_gather( + global_group_sizes[None, :], + axis_name=ep_axis, + axis=0, + tiled=True, + ) # [num_ep, num_experts] + + # -- Stage 3: forward ragged_all_to_all over EP -- + in_off, send_sz, out_off, recv_sz = compute_ragged_all_to_all_params( + all_shards_tokens_per_expert, shard_id, num_ep + ) + recv_buf = jnp.zeros( + (recv_buffer_rows, local_h), + dtype=perm["sorted_inputs"].dtype, + ) + x_recv = jax.lax.ragged_all_to_all( + perm["sorted_inputs"], + recv_buf, + in_off, + send_sz, + out_off, + recv_sz, + axis_name=ep_axis, + ) + + # -- Stage 4: local permute (source_shard, expert) -> (expert, shard) + sorted_x, local_group_sizes, local_perm_state = ( + local_permute_after_a2a( + x_recv, + all_shards_tokens_per_expert, + shard_id, + num_ep, + ) + ) + + # -- Stage 5: per-expert FFN (E_local groups) -- + local_params: dict = { + "wi_0": local["wi_0"], + "wi_1": local["wi_1"], + "wo": local["wo"], + } + if "wi_0_bias" in local: + local_params["wi_0_bias"] = local["wi_0_bias"] + local_params["wi_1_bias"] = local["wi_1_bias"] + local_params["wo_bias"] = local["wo_bias"] + expert_outputs = self._expert_ffn( + sorted_x, + local_group_sizes, + local_params, + n_groups=num_experts_local, + ) + + # -- Stage 6: invert local permute -- + x_send_back = local_unpermute_before_a2a( + expert_outputs, local_perm_state + ) + + # -- Stage 7: reverse ragged_all_to_all over EP -- + in_off_r, send_sz_r, out_off_r, recv_sz_r = ( + 
compute_reverse_ragged_all_to_all_params( + all_shards_tokens_per_expert, shard_id, num_ep + ) + ) + send_back_buf = jnp.zeros_like(perm["sorted_inputs"]) + y_back = jax.lax.ragged_all_to_all( + x_send_back, + send_back_buf, + in_off_r, + send_sz_r, + out_off_r, + recv_sz_r, + axis_name=ep_axis, + ) + + # -- Stage 8: invert global permute, weighted sum over top-k -- + output = self._global_combine( + y_back, perm, batch_size=local_b, sequence_length=local_s + ) + + if self.tensor_parallelism_axis is not None: + output = jax.lax.psum_scatter( + output, + self.tensor_parallelism_axis, + scatter_dimension=2, + tiled=True, + ) + + # ``out_specs`` must match the returned pytree structurally, + # so always emit a real scalar for aux_loss; the outer + # ``__call__`` re-strips it to None when aux_loss_coeff <= 0. + if aux_loss is None: + aux_loss = jnp.zeros((), dtype=self.dtype) + return output, aux_loss + + # ``check_rep=False`` disables shard_map's invariant that any + # output declared as ``P()`` is replicated across ``ep_axis``. + # We use ``axis_index(ep_axis)`` inside ``_a2a_fn`` so the body + # is genuinely non-replicated, which would otherwise (correctly) + # fail the check. ``ragged_all_to_all`` already produces the + # right cross-shard semantics; this is the standard JAX escape + # hatch when collectives + per-shard logic coexist. 
+ return shard_map( + _a2a_fn, + mesh=mesh, + in_specs=in_specs, + out_specs=(P(ep_axis, None, None), P()), + check_rep=False, + )(captured) diff --git a/transformer_engine/jax/permutation.py b/transformer_engine/jax/permutation.py index 1a492ba186..f4599a7b8f 100644 --- a/transformer_engine/jax/permutation.py +++ b/transformer_engine/jax/permutation.py @@ -52,6 +52,11 @@ "unfused_token_dispatch", "unfused_token_combine", "UnfusedPermState", + # Ragged-all-to-all expert-parallelism helpers + "compute_ragged_all_to_all_params", + "compute_reverse_ragged_all_to_all_params", + "local_permute_after_a2a", + "local_unpermute_before_a2a", ] @@ -989,3 +994,334 @@ def unfused_token_combine( reshaped_weights, ) return output.reshape(batch_size, sequence_length, hidden_size) + + +# ============================================================================= +# Ragged-all-to-all expert-parallelism helpers +# ============================================================================= +# +# These helpers support the ragged-all-to-all (A2A / A2Av) EP strategy used by +# :class:`transformer_engine.jax.flax.MoEBlock`. The forward EP path looks +# like:: +# +# route -> global_permute -> AG(group_sizes, ep) +# -> ragged_all_to_all(fwd, ep) +# -> local_permute_after_a2a +# -> grouped_dense x3 + activation +# -> local_unpermute_before_a2a +# -> ragged_all_to_all(reverse, ep) +# -> global_combine +# +# The two ``compute_*_ragged_all_to_all_params`` functions translate +# ``all_shards_tokens_per_expert`` (an EP-axis ``all_gather`` of each shard's +# global ``group_sizes``) into the four ``ragged_all_to_all`` arguments +# (``input_offsets``, ``send_sizes``, ``output_offsets``, ``recv_sizes``). +# ``shard_id`` may be a traced value (e.g. from :func:`jax.lax.axis_index`), +# which is why every slice into ``all_shards_tokens_per_expert`` uses +# :func:`jax.lax.dynamic_slice`. 
+# +# These functions are pure JAX (no MaxText / TE dependencies) and equivalent +# to :func:`maxtext.layers.te_permutation.compute_ragged_all_to_all_params` +# / :func:`compute_reverse_ragged_all_to_all_params`. + + +def compute_ragged_all_to_all_params( + all_shards_tokens_per_expert: jnp.ndarray, + shard_id: jnp.ndarray, + num_expert_shards: int, +) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """Forward-direction ragged_all_to_all parameters. + + Computes the four index/size arrays that :func:`jax.lax.ragged_all_to_all` + consumes for the **forward** EP shuffle, where each shard sends its + expert-grouped tokens to the shard that owns those experts. + + Parameters + ---------- + all_shards_tokens_per_expert : jnp.ndarray + Per-shard, per-expert token counts gathered across the EP axis. Shape + ``[num_expert_shards, num_experts]`` and integer dtype. + shard_id : jnp.ndarray + Index of the current shard along the EP axis (typically + :func:`jax.lax.axis_index` of the EP axis). Must be a 0-d integer. + num_expert_shards : int + Static EP-axis size. Must match + ``all_shards_tokens_per_expert.shape[0]``. + + Returns + ------- + input_offsets : jnp.ndarray + Shape ``[num_expert_shards]``. Cumulative ``send_sizes`` (with a + leading 0) -- where in the local source buffer each destination + shard's chunk begins. + send_sizes : jnp.ndarray + Shape ``[num_expert_shards]``. ``send_sizes[i]`` is the number of + tokens this shard sends to shard ``i`` (= the sum of token counts + for the experts owned by shard ``i``). + output_offsets : jnp.ndarray + Shape ``[num_expert_shards]``. ``output_offsets[i]`` is the row in + shard ``i``'s receive buffer where this shard's contribution should + land. Sender-side semantics, per :func:`jax.lax.ragged_all_to_all`. + recv_sizes : jnp.ndarray + Shape ``[num_expert_shards]``. ``recv_sizes[i]`` is the number of + tokens shard ``i`` sends to this shard. 
+ """ + num_experts = all_shards_tokens_per_expert.shape[1] + assert num_experts % num_expert_shards == 0, ( + f"num_experts={num_experts} must be divisible by num_expert_shards" + f"={num_expert_shards}" + ) + local_expert_size = num_experts // num_expert_shards + + # This shard's row of the gathered table, reshaped so axis 0 indexes the + # destination shard and axis 1 indexes its local experts. + local_tokens_per_expert = jax.lax.dynamic_slice( + all_shards_tokens_per_expert, + start_indices=(shard_id, 0), + slice_sizes=(1, num_experts), + ).squeeze(0) + local_reshaped = local_tokens_per_expert.reshape( + num_expert_shards, local_expert_size + ) + + # send_sizes[i] = sum of token counts for shard i's experts in our buffer. + send_sizes = jnp.sum(local_reshaped, axis=1) + input_offsets = jnp.concatenate( + [ + jnp.array([0], dtype=send_sizes.dtype), + jnp.cumsum(send_sizes)[:-1], + ] + ) + + # recv_sizes[i] = how many tokens shard i sends to this shard, i.e. the + # sum across our local-expert columns of shard i's row. + local_expert_start = shard_id * local_expert_size + local_expert_columns = jax.lax.dynamic_slice( + all_shards_tokens_per_expert, + start_indices=(0, local_expert_start), + slice_sizes=(num_expert_shards, local_expert_size), + ) + recv_sizes = jnp.sum(local_expert_columns, axis=1) + + # output_offsets uses sender-side semantics for ragged_all_to_all: + # output_offsets[j] = row in shard j's buffer where THIS shard's chunk + # should be placed. That's the cumulative sum (over source shards 0..j-1) + # of how many tokens those earlier source shards already sent to shard j. 
+ sends_to_target = jnp.sum( + all_shards_tokens_per_expert.reshape( + num_expert_shards, num_expert_shards, local_expert_size + ), + axis=2, + ) # [src_shard, dst_shard] + zero_row = jnp.zeros((1, num_expert_shards), dtype=sends_to_target.dtype) + cumulated = jnp.cumsum( + jnp.concatenate([zero_row, sends_to_target], axis=0), + axis=0, + dtype=sends_to_target.dtype, + ) # [src_shard + 1, dst_shard]; row r = total sent by sources 0..r-1 + output_offsets = jax.lax.dynamic_slice( + cumulated, + start_indices=(shard_id, 0), + slice_sizes=(1, num_expert_shards), + ).squeeze(0) + + return input_offsets, send_sizes, output_offsets, recv_sizes + + +def compute_reverse_ragged_all_to_all_params( + all_shards_tokens_per_expert: jnp.ndarray, + shard_id: jnp.ndarray, + num_expert_shards: int, +) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """Reverse-direction ragged_all_to_all parameters. + + Mirror of :func:`compute_ragged_all_to_all_params` for the **reverse** + EP shuffle that returns expert outputs to their source shards. The + sender / receiver roles are swapped: what we received in the forward + shuffle we now send back, and vice versa. + + Parameters and shapes are identical to + :func:`compute_ragged_all_to_all_params`. + """ + num_experts = all_shards_tokens_per_expert.shape[1] + assert num_experts % num_expert_shards == 0, ( + f"num_experts={num_experts} must be divisible by num_expert_shards" + f"={num_expert_shards}" + ) + local_expert_size = num_experts // num_expert_shards + + local_expert_start = shard_id * local_expert_size + + # In reverse, what we received becomes what we send. send_sizes[i] is how + # many tokens we send back to source shard i (= what shard i originally + # sent us, summed across our local experts). 
+ local_expert_columns = jax.lax.dynamic_slice( + all_shards_tokens_per_expert, + start_indices=(0, local_expert_start), + slice_sizes=(num_expert_shards, local_expert_size), + ) + send_sizes = jnp.sum(local_expert_columns, axis=1) + input_offsets = jnp.concatenate( + [ + jnp.array([0], dtype=send_sizes.dtype), + jnp.cumsum(send_sizes)[:-1], + ] + ) + + # recv_sizes[i] = how many tokens we receive back from shard i (= what + # we originally sent to shard i in the forward). + local_tokens_per_expert = jax.lax.dynamic_slice( + all_shards_tokens_per_expert, + start_indices=(shard_id, 0), + slice_sizes=(1, num_experts), + ).squeeze(0) + local_reshaped = local_tokens_per_expert.reshape( + num_expert_shards, local_expert_size + ) + recv_sizes = jnp.sum(local_reshaped, axis=1) + + # output_offsets: the reverse sends-to-target matrix is the transpose of + # the forward one (row i = what shard i sends in reverse = what shard i + # received in forward). Cumsum down source-shard axis, then index our row. + fwd_sends_to = jnp.sum( + all_shards_tokens_per_expert.reshape( + num_expert_shards, num_expert_shards, local_expert_size + ), + axis=2, + ) # forward: [src, dst] + rev_sends_to = jnp.transpose(fwd_sends_to) # reverse: [src, dst] + zero_row = jnp.zeros((1, num_expert_shards), dtype=rev_sends_to.dtype) + rev_cumulated = jnp.cumsum( + jnp.concatenate([zero_row, rev_sends_to], axis=0), + axis=0, + dtype=rev_sends_to.dtype, + ) + output_offsets = jax.lax.dynamic_slice( + rev_cumulated, + start_indices=(shard_id, 0), + slice_sizes=(1, num_expert_shards), + ).squeeze(0) + + return input_offsets, send_sizes, output_offsets, recv_sizes + + +# ----------------------------------------------------------------------------- +# Local permute / unpermute +# ----------------------------------------------------------------------------- +# +# After the forward ragged_all_to_all the receive buffer is laid out as +# ``[from_shard_0_chunk | from_shard_1_chunk | ... 
]`` and within each chunk +# tokens are sorted by local-expert id. To feed ``grouped_dense`` we want +# ``[expert_0_block | expert_1_block | ... ]`` where each expert's block +# contains tokens from every source shard. ``local_permute_after_a2a`` +# performs that reorder; ``local_unpermute_before_a2a`` undoes it before the +# reverse ragged_all_to_all. +# +# Implementation uses :func:`sort_chunks_by_index`, which is Triton-backed +# (see ``transformer_engine.jax.triton_extensions.permutation``) and has a +# paired custom-VJP backward. There is no pure-JAX alternative here -- the +# global :func:`unfused_token_dispatch` / :func:`token_dispatch` choice is +# unaffected by this; only the (small) post-A2A chunk reorder uses Triton +# unconditionally. + + +def local_permute_after_a2a( + x_recv: jnp.ndarray, + all_shards_tokens_per_expert: jnp.ndarray, + shard_id: jnp.ndarray, + num_expert_shards: int, +) -> Tuple[jnp.ndarray, jnp.ndarray, dict]: + """Reorder tokens received via ragged_all_to_all so each local expert's + tokens are contiguous. + + This is the EP-side complement to the global :func:`token_dispatch` / + :func:`unfused_token_dispatch`. Internally uses + :func:`sort_chunks_by_index` (Triton-backed) for both the forward sort + and -- via :func:`local_unpermute_before_a2a` -- the inverse. + + Parameters + ---------- + x_recv : jnp.ndarray + Output of the forward ``ragged_all_to_all`` of shape + ``[buffer_size, hidden_size]``. Layout: source-shard major, then + local-expert id within each source chunk. + all_shards_tokens_per_expert : jnp.ndarray + Per-shard, per-expert token counts of shape + ``[num_expert_shards, num_experts]``. + shard_id : jnp.ndarray + Current EP shard index (typically a traced + :func:`jax.lax.axis_index`). + num_expert_shards : int + Static EP-axis size. + + Returns + ------- + sorted_x : jnp.ndarray + Tokens reordered into expert-major layout. Same shape as ``x_recv``. 
+ local_group_sizes : jnp.ndarray + Per-local-expert token counts of shape ``[local_expert_size]``. + state : dict + Opaque state for :func:`local_unpermute_before_a2a`. + """ + num_experts = all_shards_tokens_per_expert.shape[1] + assert num_experts % num_expert_shards == 0, ( + f"num_experts={num_experts} must be divisible by num_expert_shards" + f"={num_expert_shards}" + ) + local_expert_size = num_experts // num_expert_shards + local_expert_start = shard_id * local_expert_size + local_expert_columns = jax.lax.dynamic_slice( + all_shards_tokens_per_expert, + start_indices=(0, local_expert_start), + slice_sizes=(num_expert_shards, local_expert_size), + ) + + # Flat sizes in source-major order, matching the receive buffer layout: + # [(s0,e0), (s0,e1), ..., (s1,e0), (s1,e1), ...] + split_sizes = local_expert_columns.reshape(-1) + + # Permutation that maps source-major -> expert-major: + # original index = s * E_local + e + # target index = e * num_shards + s + indices_matrix = jnp.arange( + num_expert_shards * local_expert_size, dtype=jnp.int32 + ).reshape(num_expert_shards, local_expert_size) + sorted_chunk_indices = indices_matrix.T.reshape(-1) + + sorted_x, _ = sort_chunks_by_index(x_recv, split_sizes, sorted_chunk_indices) + sorted_split_sizes = split_sizes[sorted_chunk_indices] + inverse_chunk_indices = jnp.argsort(sorted_chunk_indices) + local_group_sizes = jnp.sum(local_expert_columns, axis=0) + state = { + "sorted_split_sizes": sorted_split_sizes, + "inverse_chunk_indices": inverse_chunk_indices, + } + return sorted_x, local_group_sizes, state + + +def local_unpermute_before_a2a( + expert_outputs: jnp.ndarray, + state: dict, +) -> jnp.ndarray: + """Inverse of :func:`local_permute_after_a2a`. + + Parameters + ---------- + expert_outputs : jnp.ndarray + Output of the local expert FFN of shape ``[buffer_size, hidden_size]``, + in expert-major layout. + state : dict + Opaque state returned by :func:`local_permute_after_a2a`. 
+ + Returns + ------- + unsorted_x : jnp.ndarray + Tokens reordered back into source-shard-major layout, ready for the + reverse ``ragged_all_to_all``. Same shape as ``expert_outputs``. + """ + out, _ = sort_chunks_by_index( + expert_outputs, + state["sorted_split_sizes"], + state["inverse_chunk_indices"], + ) + return out From 626aae4dccf20ec78a21082ca5b58aa50dcd88ad Mon Sep 17 00:00:00 2001 From: tdophung Date: Thu, 30 Apr 2026 14:08:55 -0700 Subject: [PATCH 6/6] fix test_distributed issues with unpopulated LogicallyPartition pytree and single device initial params in the MoEBlock. Tests should pass now Signed-off-by: tdophung --- tests/jax/test_distributed_moe_block.py | 29 ++++++++++++++++++++++++- transformer_engine/jax/flax/moe.py | 2 +- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/tests/jax/test_distributed_moe_block.py b/tests/jax/test_distributed_moe_block.py index 9d9e57140f..1c7b99cda4 100644 --- a/tests/jax/test_distributed_moe_block.py +++ b/tests/jax/test_distributed_moe_block.py @@ -112,7 +112,34 @@ def loss_fn(block, variables, x): with mesh, autocast(enabled=False, mesh_resource=MeshResource(fsdp_resource="fsdp")): with nn.logical_axis_rules(logical_axis_rules): - sharded_variables = sharded_block.init(init_key, inputs) + # ``MoEBlock`` registers params via ``with_logical_partitioning`` + # which only attaches LogicallyPartitioned metadata; the + # underlying jax.Array stays single-device unless ``init`` + # is run inside ``jax.jit`` with ``out_shardings``. Use the + # canonical Flax-Linen pattern (mirrors + # ``examples/jax/encoder/test_model_parallel_encoder.py``): + # 1. ``jax.eval_shape`` to trace abstract variables (keeps + # the LogicallyPartitioned wrappers; only the inner + # arrays become ShapeDtypeStruct); + # 2. ``nn.get_partition_spec`` to extract a tree of logical + # PartitionSpecs from those wrappers (treats + # LogicallyPartitioned as a leaf); + # 3. 
``nn.logical_to_mesh_sharding`` to resolve those + # logical specs to NamedShardings via the active rules; + # 4. ``jax.jit(init, out_shardings=...)`` to actually + # place the params on-device with those shardings. + abstract_variables = jax.eval_shape( + sharded_block.init, init_key, inputs + ) + logical_partition_spec = nn.get_partition_spec( + abstract_variables + ) + out_shardings = nn.logical_to_mesh_sharding( + logical_partition_spec, mesh, logical_axis_rules + ) + sharded_variables = jax.jit( + sharded_block.init, out_shardings=out_shardings + )(init_key, inputs) (sharded_loss, (sharded_output, sharded_aux)), sharded_grads = ( jax.value_and_grad(loss_fn, argnums=1, has_aux=True)( sharded_block, sharded_variables, inputs diff --git a/transformer_engine/jax/flax/moe.py b/transformer_engine/jax/flax/moe.py index 690d804e38..050cbe84d0 100644 --- a/transformer_engine/jax/flax/moe.py +++ b/transformer_engine/jax/flax/moe.py @@ -937,7 +937,7 @@ def _a2a_fn(local: dict) -> Tuple[jnp.ndarray, jnp.ndarray]: return shard_map( _a2a_fn, mesh=mesh, - in_specs=in_specs, + in_specs=(in_specs,), out_specs=(P(ep_axis, None, None), P()), check_rep=False, )(captured)