Add all-gather + matmul ring tutorial #124
Open
yongweiy wants to merge 1 commit into
Conversation
New tutorial demonstrating a fused all-gather + matmul along a ring of TP ranks, using `nki.collectives.collective_permute_implicit` to overlap communication with compute. Adds `allgather_matmul_ring/`: - `allgather_matmul_ring_nki_kernels.py` — the ring kernel - `allgather_matmul_ring_torch.py` — torch_xla runner + reference check Validated on trn2 at TP=16 LNC=2 with rel_err < 0.003.
Summary
Adds a new tutorial under `src/nki_samples/tutorials/allgather_matmul_ring/` demonstrating a fused all-gather + matmul along a ring of TP ranks, using `nki.collectives.collective_permute_implicit` (CPI) to overlap communication with compute. Each ring step, a rank computes a local matmul against the LHS fragment currently in its ring buffer, then passes the fragment on to the next rank and receives the previous rank's fragment; the scheduler places the matmul of step i and the CPI of step i on disjoint engines so they run concurrently. The matmul is row-parallel (LHS row-sharded, RHS column-sharded). After `RANK_N` ring steps, every rank has computed one `(M_LOCAL, N_LOCAL)` slot of the fully-gathered output for every source rank.
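The schedule can be modeled end to end on the host. Below is a minimal NumPy sketch of the data movement only (the device kernel additionally overlaps the step-i matmul with the step-i CPI transfer); the shapes and the send direction (rank → rank + 1) are illustrative assumptions, not taken from the kernel.

```python
import numpy as np

# Host-side model of the ring all-gather + matmul schedule (illustrative shapes).
RANK_N, M_LOCAL, K, N_LOCAL = 4, 8, 16, 8

rng = np.random.default_rng(0)
lhs_shards = [rng.standard_normal((M_LOCAL, K)) for _ in range(RANK_N)]   # row-sharded LHS
rhs_shards = [rng.standard_normal((K, N_LOCAL)) for _ in range(RANK_N)]  # column-sharded RHS

# Each rank's ring buffer starts with its own LHS fragment.
buffers = [lhs.copy() for lhs in lhs_shards]
outputs = [np.zeros((RANK_N * M_LOCAL, N_LOCAL)) for _ in range(RANK_N)]

for step in range(RANK_N):
    # Compute: every rank multiplies the fragment currently in its buffer
    # against its local RHS shard, filling one (M_LOCAL, N_LOCAL) slot.
    for r in range(RANK_N):
        src = (r - step) % RANK_N  # source rank of the fragment r currently holds
        outputs[r][src * M_LOCAL:(src + 1) * M_LOCAL, :] = buffers[r] @ rhs_shards[r]
    # Communicate: shift fragments one hop around the ring (what CPI does on-device).
    buffers = [buffers[(r - 1) % RANK_N] for r in range(RANK_N)]

# Every rank now holds its (M, N_LOCAL) slab of the fully-gathered product.
full_lhs = np.concatenate(lhs_shards, axis=0)
for r in range(RANK_N):
    assert np.allclose(outputs[r], full_lhs @ rhs_shards[r])
```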
Follows the existing tutorial layout:

- `allgather_matmul_ring_nki_kernels.py` — NKI kernel (guarded with `NKI_EXAMPLE_AGMM_RING_*` markers).
- `allgather_matmul_ring_torch.py` — PyTorch/XLA runner: `xmp.spawn` across TP ranks, builds deterministic LHS/RHS shards, validates each rank's output against a reference matmul computed on the host (see the runner sketch after the test plan).

Test plan
- `NEURON_CC_FLAGS="--lnc=2" NEURON_LOGICAL_NC_CONFIG=2 NEURONCORE_NUM_DEVICES=16 python allgather_matmul_ring_torch.py` — all 16 ranks PASS with `rel_err ≈ 0.0022`.
- `NEURONCORE_NUM_DEVICES=4` — PASS.
- Other TP sizes may fail if the replica group does not map to a valid CPI ring topology on the hardware; documented in the tutorial's module docstring.
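For reference, a minimal sketch of the runner pattern described above, using the standard `torch_xla` multiprocessing API. `run_agmm_ring` is a hypothetical placeholder for the NKI kernel invocation defined in `allgather_matmul_ring_nki_kernels.py`, and the ring size and shard shapes are illustrative assumptions.

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

TP = 16                               # assumed ring size; must match NEURONCORE_NUM_DEVICES
M_LOCAL, K, N_LOCAL = 128, 512, 128   # illustrative shard shapes

def _mp_fn(index):
    device = xm.xla_device()
    rank = xm.get_ordinal()

    # Deterministic per-rank shards so every rank can rebuild the full LHS on the host.
    torch.manual_seed(0)
    lhs_shards = [torch.randn(M_LOCAL, K) for _ in range(TP)]
    rhs_shards = [torch.randn(K, N_LOCAL) for _ in range(TP)]

    # Device-side fused all-gather + matmul (placeholder for the NKI kernel call).
    out = run_agmm_ring(lhs_shards[rank].to(device), rhs_shards[rank].to(device))

    # Host reference: full LHS against this rank's RHS column shard.
    ref = torch.cat(lhs_shards, dim=0) @ rhs_shards[rank]
    rel_err = (out.cpu() - ref).norm() / ref.norm()
    print(f"rank {rank}: {'PASS' if rel_err < 3e-3 else 'FAIL'} rel_err={rel_err:.4f}")

if __name__ == "__main__":
    xmp.spawn(_mp_fn, args=())
```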