
Add all-gather + matmul ring tutorial #124

Open

yongweiy wants to merge 1 commit into aws-neuron:main from yongweiy:allgather-matmul-ring

Conversation


@yongweiy yongweiy commented May 8, 2026

Summary

Adds a new tutorial under src/nki_samples/tutorials/allgather_matmul_ring/ demonstrating a fused all-gather + matmul along a ring of TP ranks, using nki.collectives.collective_permute_implicit (CPI) to overlap communication with compute.

On each ring step, a rank computes a local matmul against the LHS fragment currently in its ring buffer, then forwards that fragment to the next rank while receiving the previous rank's fragment; the scheduler places the matmul of step i and the CPI of step i on disjoint engines so they run concurrently. The matmul is row-parallel (LHS row-sharded, RHS column-sharded). After RANK_N ring steps, every rank has computed one (M_LOCAL, N_LOCAL) slot of the fully gathered output for every source rank.
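
To make the data flow concrete, here is a minimal host-side NumPy sketch of the same ring schedule. Shapes and names (`RANK_N`, `M_LOCAL`, `N_LOCAL`, `ring_allgather_matmul`) are illustrative only; on device, each step's matmul and CPI transfer overlap rather than running sequentially as they do here.

```python
import numpy as np

# Illustrative sizes (hypothetical; the tutorial's real shapes may differ).
RANK_N, M_LOCAL, K, N_LOCAL = 4, 8, 16, 8
rng = np.random.default_rng(0)
lhs_shards = [rng.standard_normal((M_LOCAL, K)) for _ in range(RANK_N)]   # LHS row-sharded
rhs_shards = [rng.standard_normal((K, N_LOCAL)) for _ in range(RANK_N)]  # RHS column-sharded

def ring_allgather_matmul(rank):
    """Simulate one rank: return its (RANK_N * M_LOCAL, N_LOCAL) output."""
    out = np.empty((RANK_N * M_LOCAL, N_LOCAL))
    frag = lhs_shards[rank]  # ring buffer starts holding the local LHS shard
    src = rank               # source rank of the fragment currently buffered
    for _ in range(RANK_N):
        # Matmul on the buffered fragment; on device the CPI send of this
        # fragment to the next rank runs concurrently on a different engine.
        out[src * M_LOCAL:(src + 1) * M_LOCAL] = frag @ rhs_shards[rank]
        # Receive the previous rank's fragment (CPI on hardware, simulated here).
        src = (src - 1) % RANK_N
        frag = lhs_shards[src]
    return out

# Every rank's output matches the reference built from the fully gathered LHS.
lhs_full = np.concatenate(lhs_shards, axis=0)
for r in range(RANK_N):
    assert np.allclose(ring_allgather_matmul(r), lhs_full @ rhs_shards[r])
```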

Follows the existing tutorial layout:

  • allgather_matmul_ring_nki_kernels.py — NKI kernel (guarded with NKI_EXAMPLE_AGMM_RING_* markers).
  • allgather_matmul_ring_torch.py — PyTorch/XLA runner: spawns one process per TP rank via xmp.spawn, builds deterministic LHS/RHS shards, and validates each rank's output against a reference matmul computed on the host (see the sketch below).
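
A hedged sketch of the runner's validation step, assuming hypothetical helper names (`make_shards`, `check_rank`) and a per-rank seeding scheme; the actual logic lives in allgather_matmul_ring_torch.py:

```python
import torch

def make_shards(rank, m_local=128, k=512, n_local=128):
    # Deterministic per-rank shards: seeding by rank lets every process
    # rebuild any rank's data when computing the host-side reference.
    lhs = torch.randn(m_local, k, generator=torch.Generator().manual_seed(rank))
    rhs = torch.randn(k, n_local, generator=torch.Generator().manual_seed(10_000 + rank))
    return lhs, rhs

def check_rank(rank, world_size, device_out):
    # Host reference: fully gathered LHS times this rank's RHS column shard.
    lhs_full = torch.cat([make_shards(r)[0] for r in range(world_size)], dim=0)
    rhs = make_shards(rank)[1]
    ref = lhs_full @ rhs
    return ((device_out - ref).norm() / ref.norm()).item()  # relative error
```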

Test plan

  • trn2, TP=16 LNC=2 (4 devices): NEURON_CC_FLAGS="--lnc=2" NEURON_LOGICAL_NC_CONFIG=2 NEURONCORE_NUM_DEVICES=16 python allgather_matmul_ring_torch.py — all 16 ranks PASS with rel_err ≈ 0.0022.
  • trn2, TP=4 LNC=2 (single device): same script, NEURONCORE_NUM_DEVICES=4 — PASS.

Other TP sizes may fail if the replica group does not map to a valid CPI ring topology on the hardware; this limitation is documented in the tutorial's module docstring.

Commit message

New tutorial demonstrating a fused all-gather + matmul along a ring of
TP ranks, using `nki.collectives.collective_permute_implicit` to
overlap communication with compute.

Adds `allgather_matmul_ring/`:
  - `allgather_matmul_ring_nki_kernels.py` — the ring kernel
  - `allgather_matmul_ring_torch.py` — torch_xla runner + reference check

Validated on trn2 at TP=16 LNC=2 with rel_err < 0.003.
