Add support for late-interaction-kernels (LIK) #412

Open
tonywu71 wants to merge 5 commits into illuin-tech:main from tonywu71:add-support-for-late-interaction-kernels

Contributor

@tonywu71 tonywu71 commented May 22, 2026

Summary

Adds an optional [lik] extra that routes ColBERT MaxSim scoring through late-interaction-kernels (LIK), a fused Triton kernel, on CUDA Ampere+ / Apple Silicon, with a transparent pure-torch fallback everywhere else. It is opt-in and feature-flagged (COLPALI_SCORES_BACKEND selects auto/torch/lik), with no change to the public API or training semantics: the kernel and the torch reference return the same scores and the same loss.

On the MaxSim operation itself, the kernel's speedup is unambiguous: isolating the loss head on ColPali-like shapes (no encoder), LIK runs forward+backward 2.5–4.3× faster at large batch×negatives, and the gain grows with B × n_neg (LIK 0.4.1 benchmarks). In ColVision training, however, MaxSim takes a few milliseconds inside a step dominated by the 2B-parameter model forward/backward, so the op speedup dilutes to per-step parity end-to-end. What survives the dilution is the memory win: measured at the op level, vanilla MaxSim costs 7.8 GiB of VRAM at B=128 where LIK costs 62 MiB (129× less), and that B²-growing term is exactly what caps the trainable batch size. Removing it doubles the batch size that fits on an 80 GB H100 (vanilla OOMs at B=128, LIK trains it).

What this PR adds

  • colpali_engine/utils/maxsim.py: the dispatcher (maxsim_inbatch, maxsim_kd) selecting between the LIK backend and the torch einsum + amax + sum reference per COLPALI_SCORES_BACKEND.
  • colpali_engine/utils/_lik_backend.py: the lazily-imported LIK implementations, input validation, and the LIKUnsupportedError sentinel.
  • Routes score_multi_vector, the three in-batch ColBERT losses (ColbertLoss, ColbertPairwiseCELoss, ColbertSigmoidLoss), and the two negative-doc losses (ColbertNegativeCELoss, ColbertPairwiseNegativeCELoss, via LIK's kd_layout) through the dispatcher.
  • pyproject.toml: optional extra lik = ["late-interaction-kernels>=0.4.1,<0.5.0"]; README section documenting the extra and the env var.
  • CPU dispatch tests plus CUDA parity and training-smoke tests.
  • Two transformers-5.x trainer fixes hit while validating the path (_get_train_sampler signature; single-dataset compute_loss prefixes).

The benchmarking harnesses used for the numbers below were added and then removed within this PR's history, keeping the final tree lean: the op-level VRAM harness and its results live at 1717e37, the original batch-size sweep at 2749bd5 (a pre-rebase commit GitHub keeps accessible).

Design

maxsim_inbatch(Q, D) handles the in-batch [B, Lq, d] x [B, Ld, d] grid (used by score_multi_vector and the in-batch losses); maxsim_kd(Q, D) handles the per-query candidate layout [B, N, Ld, d] (negative-doc losses). The LIK implementations live in a lazily-imported _lik_backend module that validates each call (CUDA Ampere+ or MPS, embedding dim above the kernel's tile floor, matching devices) and raises a LIKUnsupportedError sentinel when the kernel cannot run; real kernel errors always propagate. Both paths treat padded tokens as exactly-zero embeddings rather than an explicit mask; ColQwen2 already zeroes padded positions through the attention mask, so the scores match.
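To make the two layouts concrete, here is a minimal pure-torch sketch of the einsum + amax + sum reference for both shapes. It is not the PR's code: the function names `maxsim_inbatch_ref` and `maxsim_kd_ref` and the einsum index letters are chosen for illustration only.

```python
import torch


def maxsim_inbatch_ref(Q: torch.Tensor, D: torch.Tensor) -> torch.Tensor:
    """Illustrative reference: Q [B, Lq, d], D [B, Ld, d] -> scores [B, B]."""
    # Token-level similarity grid for every (query, doc) pair: [B, B, Lq, Ld].
    sim = torch.einsum("qnd,cmd->qcnm", Q, D)
    # Max over doc tokens, then sum over query tokens.
    return sim.amax(dim=-1).sum(dim=-1)


def maxsim_kd_ref(Q: torch.Tensor, D: torch.Tensor) -> torch.Tensor:
    """Illustrative reference: Q [B, Lq, d], D [B, N, Ld, d] -> scores [B, N]."""
    # Each query is scored only against its own N candidates: [B, N, Lq, Ld].
    sim = torch.einsum("bnd,bcmd->bcnm", Q, D)
    return sim.amax(dim=-1).sum(dim=-1)


if __name__ == "__main__":
    torch.manual_seed(0)
    Q = torch.randn(3, 4, 8)
    D = torch.randn(3, 5, 8)
    scores = maxsim_inbatch_ref(Q, D)
    # Padding the queries with exactly-zero tokens leaves the scores unchanged:
    # a zero query token has similarity 0 with every doc token, so it adds 0.
    Q_padded = torch.cat([Q, torch.zeros(3, 2, 8)], dim=1)
    print(torch.allclose(maxsim_inbatch_ref(Q_padded, D), scores, atol=1e-5))
```

The in-batch variant materializes the full [B, B, Lq, Ld] grid, which is the tensor whose memory cost the Results section measures.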

The design deliberately matches PyLate's integration (lightonai/pylate#222), so using both libraries means one mental model: the extra is [lik], the backend module split is the same, and COLPALI_SCORES_BACKEND (read per call) mirrors PYLATE_SCORES_BACKEND with the same three values: auto (default) uses LIK when eligible and silently falls back to torch, torch forces the reference, and lik is strict, raising LIKUnsupportedError instead of falling back.
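The three backend values described above can be sketched as a small dispatch function. This is an assumption-laden illustration, not the contents of `maxsim.py`: `dispatch`, `torch_impl`, and `lik_impl` are hypothetical names, the `LIKUnsupportedError` class here is a local stand-in, and rejecting unknown values with `ValueError` is a choice made for the sketch.

```python
import os


class LIKUnsupportedError(RuntimeError):
    """Local stand-in for the sentinel the PR describes."""


def dispatch(torch_impl, lik_impl, *args, env_var="COLPALI_SCORES_BACKEND"):
    # Read the variable on every call, so it can be changed between calls.
    backend = os.environ.get(env_var, "auto")
    if backend not in ("auto", "torch", "lik"):
        # Assumption of this sketch: unknown values are rejected loudly.
        raise ValueError(f"{env_var} must be auto, torch or lik, got {backend!r}")
    if backend == "torch":
        return torch_impl(*args)
    try:
        return lik_impl(*args)
    except LIKUnsupportedError:
        if backend == "lik":
            raise  # strict mode: surface the sentinel instead of falling back
        return torch_impl(*args)  # auto mode: silent fallback to the reference
    # Any other exception from lik_impl is a real kernel error and propagates.
```

Only the sentinel triggers the fallback, so a genuine kernel failure is never masked by the torch path.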

Results

The kernel itself is much faster than the einsum it replaces. LIK's 0.4.1 benchmarks isolate the loss head on ColPali-like shapes (Lq=32, Ld=1030, no encoder, forward+backward at matched numerics): the speedup climbs 1.13× → 4.31× as B × n_neg grows (2.50× at B256×n8, 4.31× at B256×n16), with ~25–30% lower peak memory on the head. In LIK's own words, this is "the throughput the encoder hides" in end-to-end training.

End-to-end, the model forward dilutes that speedup to parity. A ColQwen2 training step is dominated by the 2B-parameter doc/query towers; MaxSim is a few milliseconds of a >1 s step. We measure per-step parity (B=64: 7.19 vs 7.23 samples/s), and LIK's own end-to-end ColQwen2 table shows the same 0.97–1.02×. The dilution is mechanical, not a kernel property: on a 17M encoder, where MaxSim is a bigger slice of the step, the same kernel shows up as a 1.1–1.3× end-to-end speedup.
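The dilution follows Amdahl's law: if MaxSim is a fraction f of the step and the kernel speeds it up by s, the step speeds up by 1 / ((1 − f) + f / s). The sketch below works one example. The 5 ms and 1000 ms figures are hypothetical values consistent with "a few milliseconds of a >1 s step", and the 20% fraction is likewise an illustrative assumption, not a measurement from this PR.

```python
def end_to_end_speedup(op_fraction: float, op_speedup: float) -> float:
    """Amdahl's law: whole-step speedup when only one op gets faster."""
    return 1.0 / ((1.0 - op_fraction) + op_fraction / op_speedup)


if __name__ == "__main__":
    # Hypothetical: MaxSim is 5 ms of a 1000 ms step, kernel is 4.31x faster.
    print(f"{end_to_end_speedup(5 / 1000, 4.31):.4f}")
    # Hypothetical: a small encoder where MaxSim is 20% of the step.
    print(f"{end_to_end_speedup(0.20, 4.31):.4f}")
```

With the op at half a percent of the step, even a 4.31× kernel moves the total by well under 1%, which is indistinguishable from parity at 4-step measurement noise.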

What survives at ColQwen2 scale is the memory win, measured at the op level. We instrumented the dispatcher during real ColQwen2 training steps, then replayed each recorded shape on an isolated graph where the op's backward can be measured exactly (the replayed forward numbers match the in-train ones to the MiB). The VRAM attributable to MaxSim:

| batch size | vanilla: held | vanilla: bwd spike | vanilla: total | LIK: held | LIK: bwd spike | LIK: total |
|---|---|---|---|---|---|---|
| 16 | 32 MiB | 73 MiB | 106 MiB | 1 MiB | 7 MiB | 7 MiB |
| 32 | 151 MiB | 339 MiB | 489 MiB | 1 MiB | 13 MiB | 14 MiB |
| 64 | 615 MiB | 1.35 GiB | 1.95 GiB | 3 MiB | 26 MiB | 29 MiB |
| 128 | 2.40 GiB | 5.41 GiB | 7.81 GiB | 9 MiB | 53 MiB | 62 MiB |

[Figure: maxsim_vram]

The score grid is fp32 in practice (autocast computes the embedding L2-norm in fp32 and the division promotes, so the loss runs on fp32 embeddings). At B=128 the [B,B,Lq,Ld] tensor is 2.4 GiB, held from the op's forward until its backward, where the op spikes another 2.25× that (the grid's gradient plus the amax scatter): 5.4 GiB. LIK holds only the [B,B] output and its backward allocates only the input gradients (dominated by grad_D), so its footprint grows linearly in B instead of quadratically: 62 MiB total at B=128, a 129× reduction.
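The quadratic-versus-linear scaling can be checked with a few lines of arithmetic. This is a sketch under stated assumptions: the helper names are hypothetical, and Lq=32, Ld=1030 are the LIK benchmark shapes quoted earlier, used only for illustration since the description does not state the token counts of the training runs. The embedding dim `d=128` is also an assumed placeholder.

```python
MIB = 1024 ** 2
GIB = 1024 ** 3


def score_grid_bytes(B: int, Lq: int, Ld: int, itemsize: int = 4) -> int:
    """Bytes of the fp32 [B, B, Lq, Ld] grid the einsum reference materializes."""
    return B * B * Lq * Ld * itemsize


def input_grad_bytes(B: int, Lq: int, Ld: int, d: int, itemsize: int = 4) -> int:
    """Bytes of grad_Q [B, Lq, d] plus grad_D [B, Ld, d]; linear in B."""
    return B * (Lq + Ld) * d * itemsize


if __name__ == "__main__":
    # Illustrative shapes only (assumptions, not the measured training shapes).
    Lq, Ld, d = 32, 1030, 128
    for B in (16, 32, 64, 128):
        grid = score_grid_bytes(B, Lq, Ld) / MIB
        grads = input_grad_bytes(B, Lq, Ld, d) / MIB
        print(f"B={B:>3}  grid={grid:9.1f} MiB  input grads={grads:6.1f} MiB")
    # Ratio of the two measured totals at B=128 from the table above.
    print(round(7.81 * GIB / (62 * MIB)))
```

Doubling B quadruples the grid but only doubles the input gradients, which matches the roughly 4× and 2× steps between rows of the measured table.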

That op footprint is what sets the batch-size ceiling. Sweeping per_device_train_batch_size until OOM (ColQwen2 + LoRA, grad-checkpointing, bf16, 80 GB H100): whole-step peak allocated VRAM is identical while both fit, then splits at B=128.

| batch size | vanilla (LIK off) | LIK on |
|---|---|---|
| 16 | 10.9 GiB | 10.9 GiB |
| 32 | 17.1 GiB | 17.1 GiB |
| 64 | 29.5 GiB | 29.5 GiB |
| 128 | OOM | 54.4 GiB |
| 256 | OOM | OOM |

[Figure: batch_sweep]

Note

Vanilla and LIK look identical in the VRAM table up to B=64 because the score tensors are freed before the peak: the op's backward runs first, then the model backward where the peak lives. At B=128 vanilla dies not because it uses more peak memory, but because its score grid needs multi-GiB contiguous blocks that memory fragmentation makes impossible to satisfy (the observed OOM is a 1.81 GiB request failing while 25 GiB sit reserved but unallocated). LIK's 62 MiB fits in whatever scraps remain.

Vanilla maxes out at B=64, LIK at B=128 (2× headroom). At B=256 both paths OOM: pushing 256 pages of ~768 visual tokens each through the 2B doc tower is the limit, regardless of the score tensor. The first steps pay a one-time Triton autotune warmup that amortizes over a full run. The loss matches the torch reference within bf16 noise.

Full sweep table

1× H100, vidore/colqwen2-base, ColbertPairwiseCELoss, grad-checkpointing on, LIK 0.4.1. Fresh process per (B, backend) so an OOM is isolated.

| batch size | vanilla fits? | vanilla peak alloc (MiB) | vanilla samples/s | LIK fits? | LIK peak alloc (MiB) | LIK samples/s |
|---|---|---|---|---|---|---|
| 16 | yes | 11134 | 5.50 | yes | 11133 | 2.51 |
| 32 | yes | 17497 | 6.85 | yes | 17497 | 5.76 |
| 64 | yes | 30247 | 7.23 | yes | 30247 | 7.19 |
| 128 | no (OOM) | 54.0 GiB pre-OOM | OOM | yes | 55691 | 7.98 |
| 256 | no (OOM) | n/a | OOM | no (OOM) | n/a | OOM |
| 384 | no (OOM) | n/a | OOM | no (OOM) | n/a | OOM |
| 512 | no (OOM) | n/a | OOM | no (OOM) | n/a | OOM |

Throughput is from 4-step runs, so it is autotune-warmup-affected (LIK looks slower at B=16 only because warmup dominates 4 steps). The point that matters: LIK runs B=128 at 7.98 samples/s, which vanilla cannot reach.
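The "fresh process per (B, backend)" isolation used for the sweep could look roughly like the sketch below. It is not the removed harness: `run_cell`, `sweep`, and the `BENCH_BATCH_SIZE` variable are hypothetical, and the command to run is whatever training entry point you use. Only `COLPALI_SCORES_BACKEND` comes from this PR.

```python
import os
import subprocess


def run_cell(cmd, batch_size, backend, timeout=None):
    """Run one (batch size, backend) cell in its own process."""
    env = dict(
        os.environ,
        COLPALI_SCORES_BACKEND=backend,
        BENCH_BATCH_SIZE=str(batch_size),  # hypothetical knob read by the child
    )
    proc = subprocess.run(
        cmd, env=env, capture_output=True, text=True, timeout=timeout
    )
    return {
        "batch_size": batch_size,
        "backend": backend,
        "fits": proc.returncode == 0,
        # PyTorch CUDA OOM messages contain the phrase "out of memory".
        "oom": "out of memory" in proc.stderr.lower(),
    }


def sweep(cmd, batch_sizes, backends=("torch", "auto")):
    """One fresh process per cell, so an OOM cannot poison the next cell."""
    return [run_cell(cmd, b, k) for b in batch_sizes for k in backends]
```

Because each cell gets a new CUDA context, allocator fragmentation left behind by a failed run does not carry over into the next measurement.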

Reproduce

The op-level VRAM harness and its result JSONs were added and removed within this PR's history; check out 1717e37 to get both. The harness wraps maxsim_inbatch during training to record the forward peak and the bytes held for backward, then replays each recorded shape on an isolated graph to bracket the op's backward exactly (a grad hook cannot bracket it in-train: it fires as a pre-hook of the producing node, after the whole doc-tower backward).
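For readers who do not want to check out the old commit, the forward-side recording can be approximated with a wrapper like the one below. This is a rough sketch, not the harness at 1717e37: `record_calls` and the log field names are hypothetical, and it covers only the forward peak and the bytes still allocated after the forward, not the backward replay.

```python
import functools

import torch


def record_calls(fn, log):
    """Wrap a (Q, D) -> scores function and log shapes plus CUDA memory deltas."""

    @functools.wraps(fn)
    def wrapper(Q, D):
        on_cuda = Q.is_cuda
        if on_cuda:
            torch.cuda.synchronize()
            torch.cuda.reset_peak_memory_stats()
            before = torch.cuda.memory_allocated()
        out = fn(Q, D)
        entry = {"q_shape": tuple(Q.shape), "d_shape": tuple(D.shape)}
        if on_cuda:
            torch.cuda.synchronize()
            # Peak during the forward, relative to the pre-call baseline.
            entry["fwd_peak_bytes"] = torch.cuda.max_memory_allocated() - before
            # Still allocated after the forward: the output plus whatever
            # autograd saved for the backward pass.
            entry["held_bytes"] = torch.cuda.memory_allocated() - before
        log.append(entry)
        return out

    return wrapper
```

On a CPU-only host the wrapper records shapes only, which is still enough to replay each call later on an isolated graph.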

git checkout 1717e37                                                  # op-level harness + results present here
pip install -e ".[train,lik]"
sky launch -c colpali-lik-vram bench_lik/sky_maxsim_vram_sweep.yaml   # B in {16..128} x {auto, torch}
rsync -avP 'colpali-lik-vram:sky_workdir/bench_lik/results/' ./bench_lik/results/
python bench_lik/summarize_maxsim_vram.py --results-dir bench_lik/results   # op table + maxsim_vram.png

The whole-step batch-size sweep (B up to 512) ran on an earlier harness iteration at 2749bd5 (sky_batch_sweep.yaml + summarize_sweep.py, pre-rebase commit kept accessible by GitHub). The CUDA test runner (scripts/sky_test_lik.yaml) lives at 7ae3402; the slow suite itself stays in-tree (pytest -m slow tests/utils/test_maxsim_cuda.py on a CUDA Ampere+ host).

Force a single run onto a backend with COLPALI_SCORES_BACKEND=auto|torch|lik (lik errors instead of silently falling back).

Next steps

  • When this PR is merged, drop the corresponding patches in the late-interaction-kernels repository.
  • Cut a new colpali-engine release once merged (0.3.17 if nothing else lands in between) so the [lik] extra is installable from PyPI.

@tonywu71 tonywu71 changed the title Add support for late interaction kernels Add support for late-interaction-kernels (LIK) May 22, 2026
Collaborator

@ManuelFay ManuelFay left a comment

This is super cool!
I think we need to run a real training with this kernel before merging to see what the gains are: how much can we increase the batch size, and how much speed do we gain?

This would also be a good opportunity to re-verify some of the training scripts and the docs here to make sure it's straightforward to do (I am not so sure).

@mlconti1 @antoineedy ?

@tonywu71 tonywu71 force-pushed the add-support-for-late-interaction-kernels branch from 4a733b6 to 1223fbb Compare June 4, 2026 17:55
tonywu71 added 4 commits June 4, 2026 22:45
- Add maxsim dispatcher (maxsim_inbatch, maxsim_kd) with a pure-torch
  einsum reference and a lazily-imported LIK backend
- Mirror PyLate's design: [lik] extra, COLPALI_SCORES_BACKEND env var
  (auto/torch/lik) read per call, LIKUnsupportedError sentinel
- Route score_multi_vector and the five ColBERT losses through the
  dispatcher (negative-doc losses via LIK's kd_layout)
- Add CPU dispatch tests plus CUDA parity and training-smoke tests
- Fix transformers-5.x trainer breakage (_get_train_sampler signature,
  single-dataset compute_loss prefixes)
- Document the extra and the backend toggle in README and CHANGELOG

- Add bench_train.py: runs training steps with the maxsim dispatcher
  instrumented (per-call forward peak and bytes held for backward),
  then replays each recorded shape on an isolated graph to bracket
  the op's backward exactly
- Add SkyPilot sweep over B in {16..128} x {auto, torch}, fresh
  process per cell so an OOM is isolated
- Add summarizer emitting the markdown table and the log-log plot
- Add train-subset loaders so runs skip the full 52 GB train set

8 cells from a 1x H100 run (LIK 0.4.1, ColQwen2 + LoRA): per-op
forward/held/backward VRAM plus whole-step peaks; the vanilla B=128
cell records the fragmentation OOM message.

Keep the final tree lean: the harness and results stay reachable at
the two prior commits, referenced from the PR description.
@tonywu71 tonywu71 force-pushed the add-support-for-late-interaction-kernels branch from cacf178 to 1d3363c Compare June 4, 2026 20:48