Add support for late-interaction-kernels (LIK)#412
Open
tonywu71 wants to merge 5 commits into
Open
Conversation
ManuelFay
reviewed
May 22, 2026
Collaborator
ManuelFay
left a comment
There was a problem hiding this comment.
this is super cool !
I think we need to run a real training with this kernel to see before merging what the gains are: how much can we increase batch size by, how much more speed do we gain?
this would also be a good opportunity to reverify some of the training scripts and the doc here to make sure it s starightforward to do (i am not so sure).
4a733b6 to
1223fbb
Compare
- Add maxsim dispatcher (maxsim_inbatch, maxsim_kd) with a pure-torch einsum reference and a lazily-imported LIK backend - Mirror PyLate's design: [lik] extra, COLPALI_SCORES_BACKEND env var (auto/torch/lik) read per call, LIKUnsupportedError sentinel - Route score_multi_vector and the five ColBERT losses through the dispatcher (negative-doc losses via LIK's kd_layout) - Add CPU dispatch tests plus CUDA parity and training-smoke tests - Fix transformers-5.x trainer breakage (_get_train_sampler signature, single-dataset compute_loss prefixes) - Document the extra and the backend toggle in README and CHANGELOG
- Add bench_train.py: runs training steps with the maxsim dispatcher
instrumented (per-call forward peak and bytes held for backward),
then replays each recorded shape on an isolated graph to bracket
the op's backward exactly
- Add SkyPilot sweep over B in {16..128} x {auto, torch}, fresh
process per cell so an OOM is isolated
- Add summarizer emitting the markdown table and the log-log plot
- Add train-subset loaders so runs skip the full 52 GB train set
8 cells from a 1x H100 run (LIK 0.4.1, ColQwen2 + LoRA): per-op forward/held/backward VRAM plus whole-step peaks; the vanilla B=128 cell records the fragmentation OOM message.
Keep the final tree lean: the harness and results stay reachable at the two prior commits, referenced from the PR description.
cacf178 to
1d3363c
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Adds an optional
[lik]extra that routes ColBERT MaxSim scoring throughlate-interaction-kernels(LIK), a fused Triton kernel, on CUDA Ampere+ / Apple Silicon, with a transparent pure-torch fallback everywhere else. It is opt-in and feature-flagged (COLPALI_SCORES_BACKENDselectsauto/torch/lik), with no change to the public API or training semantics: the kernel and the torch reference return the same scores and the same loss.On the MaxSim operation itself the kernel's speedup is unambiguous: isolating the loss head on ColPali-like shapes (no encoder), LIK runs the forward+backward up to 2.5–4.3× faster at large batch×negatives, with the win growing with
B × n_neg(LIK 0.4.1 benchmarks). In ColVision training, however, MaxSim is a few milliseconds inside a step dominated by the 2B-parameter model forward/backward, so the op speedup dilutes to per-step parity end-to-end. What survives the dilution is the memory win: measured at the op level, vanilla MaxSim costs 7.8 GiB of VRAM at B=128 where LIK costs 62 MiB (129×), and that B²-growing term is exactly what caps the trainable batch size. Removing it doubles the batch on an 80 GB H100 (vanilla OOMs at B=128, LIK trains it).What this PR adds
colpali_engine/utils/maxsim.py: the dispatcher (maxsim_inbatch,maxsim_kd) selecting between the LIK backend and the torcheinsum + amax + sumreference perCOLPALI_SCORES_BACKEND.colpali_engine/utils/_lik_backend.py: the lazily-imported LIK implementations, input validation, and theLIKUnsupportedErrorsentinel.score_multi_vector, the three in-batch ColBERT losses (ColbertLoss,ColbertPairwiseCELoss,ColbertSigmoidLoss), and the two negative-doc losses (ColbertNegativeCELoss,ColbertPairwiseNegativeCELoss, via LIK'skd_layout) through the dispatcher.pyproject.toml: optional extralik = ["late-interaction-kernels>=0.4.1,<0.5.0"]; README section documenting the extra and the env var._get_train_samplersignature; single-datasetcompute_lossprefixes).The benchmarking harnesses used for the numbers below were added and then removed within this PR's history, keeping the final tree lean: the op-level VRAM harness and its results live at
1717e37, the original batch-size sweep at2749bd5(a pre-rebase commit GitHub keeps accessible).Design
maxsim_inbatch(Q, D)handles the in-batch[B, Lq, d] x [B, Ld, d]grid (used byscore_multi_vectorand the in-batch losses);maxsim_kd(Q, D)handles the per-query candidate layout[B, N, Ld, d](negative-doc losses). The LIK implementations live in a lazily-imported_lik_backendmodule that validates each call (CUDA Ampere+ or MPS, embedding dim above the kernel's tile floor, matching devices) and raises aLIKUnsupportedErrorsentinel when the kernel cannot run; real kernel errors always propagate. Both paths treat padded tokens as exactly-zero embeddings rather than an explicit mask; ColQwen2 already zeroes padded positions through the attention mask, so the scores match.The design deliberately matches PyLate's integration (lightonai/pylate#222), so using both libraries means one mental model: the extra is
[lik], the backend module split is the same, andCOLPALI_SCORES_BACKEND(read per call) mirrorsPYLATE_SCORES_BACKENDwith the same three values:auto(default) uses LIK when eligible and silently falls back to torch,torchforces the reference, andlikis strict, raisingLIKUnsupportedErrorinstead of falling back.Results
The kernel itself is much faster than the einsum it replaces. LIK's 0.4.1 benchmarks isolate the loss head on ColPali-like shapes (
Lq=32, Ld=1030, no encoder, forward+backward at matched numerics): the speedup climbs 1.13× → 4.31× asB × n_neggrows (2.50× at B256×n8, 4.31× at B256×n16), with ~25–30% lower peak memory on the head. In LIK's own words, this is "the throughput the encoder hides" in end-to-end training.End-to-end, the model forward dilutes that speedup to parity. A ColQwen2 training step is dominated by the 2B-parameter doc/query towers; MaxSim is a few milliseconds of a >1 s step. We measure per-step parity (B=64: 7.19 vs 7.23 samples/s), and LIK's own end-to-end ColQwen2 table shows the same 0.97–1.02×. The dilution is mechanical, not a kernel property: on a 17M encoder, where MaxSim is a bigger slice of the step, the same kernel shows up as a 1.1–1.3× end-to-end speedup.
What survives at ColQwen2 scale is the memory win, measured at the op level. We instrumented the dispatcher during real ColQwen2 training steps, then replayed each recorded shape on an isolated graph where the op's backward can be measured exactly (the replayed forward numbers match the in-train ones to the MiB). The VRAM attributable to MaxSim:
The score grid is fp32 in practice (autocast computes the embedding L2-norm in fp32 and the division promotes, so the loss runs on fp32 embeddings). At B=128 the
[B,B,Lq,Ld]tensor is 2.4 GiB, held from the op's forward until its backward, where the op spikes another 2.25× that (the grid's gradient plus theamaxscatter): 5.4 GiB. LIK holds only the[B,B]output and its backward allocates only the input gradients (dominated bygrad_D), so its footprint grows linearly in B instead of quadratically: 62 MiB total at B=128, a 129× reduction.That op footprint is what sets the batch-size ceiling. Sweeping
per_device_train_batch_sizeuntil OOM (ColQwen2 + LoRA, grad-checkpointing, bf16, 80 GB H100): whole-step peak allocated VRAM is identical while both fit, then splits at B=128.Note
Vanilla and LIK look identical in the VRAM table up to B=64 because the score tensors are freed before the peak: the op's backward runs first, then the model backward where the peak lives. At B=128 vanilla dies not because it uses more peak memory, but because its score grid needs multi-GiB contiguous blocks that memory fragmentation makes impossible to satisfy (the observed OOM is a 1.81 GiB request failing while 25 GiB sit reserved but unallocated). LIK's 62 MiB fits in whatever scraps remain.
Vanilla maxes out at B=64, LIK at B=128 (2× headroom). At B=256 both paths OOM: pushing 256 pages of ~768 visual tokens each through the 2B doc tower is the limit, regardless of the score tensor. The first steps pay a one-time Triton autotune warmup that amortizes over a full run. The loss matches the torch reference within bf16 noise.
Full sweep table
1× H100,
vidore/colqwen2-base,ColbertPairwiseCELoss, grad-checkpointing on, LIK0.4.1. Fresh process per (B, backend) so an OOM is isolated.Throughput is from 4-step runs, so it is autotune-warmup-affected (LIK looks slower at B=16 only because warmup dominates 4 steps). The point that matters: LIK runs B=128 at 7.98 samples/s, which vanilla cannot reach.
Reproduce
The op-level VRAM harness and its result JSONs were added and removed within this PR's history; check out
1717e37to get both. The harness wrapsmaxsim_inbatchduring training to record the forward peak and the bytes held for backward, then replays each recorded shape on an isolated graph to bracket the op's backward exactly (a grad hook cannot bracket it in-train: it fires as a pre-hook of the producing node, after the whole doc-tower backward).The whole-step batch-size sweep (B up to 512) ran on an earlier harness iteration at
2749bd5(sky_batch_sweep.yaml+summarize_sweep.py, pre-rebase commit kept accessible by GitHub). The CUDA test runner (scripts/sky_test_lik.yaml) lives at7ae3402; the slow suite itself stays in-tree (pytest -m slow tests/utils/test_maxsim_cuda.pyon a CUDA Ampere+ host).Force a single run onto a backend with
COLPALI_SCORES_BACKEND=auto|torch|lik(likerrors instead of silently falling back).Next steps
late-interaction-kernelsrepository.colpali-enginerelease once merged (0.3.17if nothing else lands in between) so the[lik]extra is installable from PyPI.