late-interaction-kernels · design walkthrough
how it works

Scoring every query–document token pair without materialising the table

Late-interaction retrieval models compare every query token to every document token. A direct implementation builds a four-dimensional similarity tensor that overflows GPU memory at modern context lengths. The kernels in this library compute the same result without ever storing that tensor, by streaming tiles through on-chip memory and folding the reductions into the same pass.

1. The MaxSim score

MaxSim is the scoring function ColBERT-style late-interaction models use. The inputs are two batches of token-level embeddings:

  • $\Qm \in \mathbb{R}^{N_q \times L_q \times d}$: $N_q$ queries, each with $L_q$ token vectors of dimension $d$.
  • $\Dm \in \mathbb{R}^{N_d \times L_d \times d}$: $N_d$ documents, each with $L_d$ token vectors of the same dimension $d$.

For every (query $i$, document $j$) pair, the score is:

$$\operatorname{score}[i, j] \;=\; \sum_{s=1}^{L_q} \; \max_{t=1}^{L_d} \; \langle \Qm_{i, s}, \Dm_{j, t} \rangle.$$

For each query token, find the document token it matches best (highest inner product), then sum those best matches across all query tokens. The result is a score matrix in $\mathbb{R}^{N_q \times N_d}$. Optional boolean masks drop padding tokens from the sum and from the inner max.
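
As a point of reference, the definition above can be written directly in NumPy. This is a minimal sketch, not the library's code: the function name `maxsim_reference` and the mask argument names are assumptions made for illustration, and it deliberately materialises the full similarity tensor that the rest of this walkthrough avoids.

```python
import numpy as np

def maxsim_reference(Q, D, q_mask=None, d_mask=None):
    """Naive MaxSim: materialises the full [Nq, Nd, Lq, Ld] similarity tensor."""
    S = np.einsum("nsd,mtd->nmst", Q, D)        # every query token vs every doc token
    if d_mask is not None:                      # padded doc tokens can never win the max
        S = np.where(d_mask[None, :, None, :], S, -np.inf)
    M = S.max(axis=-1)                          # [Nq, Nd, Lq]: best doc token per query token
    if q_mask is not None:                      # padded query tokens add nothing to the sum
        M = np.where(q_mask[:, None, :], M, 0.0)
    return M.sum(axis=-1)                       # [Nq, Nd]

rng = np.random.default_rng(0)
Q = rng.standard_normal((2, 3, 4))              # Nq=2, Lq=3, d=4
D = rng.standard_normal((5, 6, 4))              # Nd=5, Ld=6, d=4
scores = maxsim_reference(Q, D)                 # shape (2, 5)

# Same call with the last token of query 0 treated as padding.
q_mask = np.array([[True, True, False], [True, True, True]])
masked_scores = maxsim_reference(Q, D, q_mask=q_mask)
```

Any fused implementation can be checked against a reference of this kind on small shapes.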

The problem to solve. Computing this naively builds the full $[N_q, N_d, L_q, L_d]$ similarity tensor first, then reduces. At ColPali scale ($N_q = N_d = 64$, $L_q = L_d = 1024$, so 4096 query–document pairs) that tensor is 8 GB in fp16. Each element is written to HBM once and read back once for only $2d$ FLOPs of matmul, so the GPU spends most of its time on HBM traffic, not compute. The fused kernels avoid this by streaming $B_q \times B_d$ tiles through on-chip memory and folding both reductions into the same pass.

2. Programs and the launch grid

A GPU kernel is a small function the device runs in many parallel instances. Each instance is a program in Triton's vocabulary, or a thread block in CUDA's. Writing a kernel comes down to two design choices: how many programs to launch, and what one program computes.

The MaxSim forward uses a grid of $N_q \cdot N_d$ programs, one per query–document pair. Program $(i, j)$ pulls the embeddings of query $i$ and document $j$ from HBM, accumulates a single fp32 scalar, and writes it to position $(i, j)$ of the score matrix. The programs are independent, so the hardware schedules them onto whatever SMs are free and runs as many concurrently as on-chip resources allow.

[Figure: launch grid with query index $i = 0, \dots, 3$ on the rows and document index $j = 0, \dots, 7$ on the columns; cell $(i, j)$ is the program that computes $\operatorname{score}[i, j]$.]
$N_q = 4$, $N_d = 8$: 32 programs, one per output cell, all running in parallel. Each program produces a single fp32 number.
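
The mapping from a program to its output cell can be sketched as below. This assumes a flat, row-major program index, which is only one possible layout; the library may equally launch a two-dimensional grid, and `program_to_pair` is a hypothetical helper.

```python
def program_to_pair(pid, n_d):
    """Map a flat program index to its (query i, document j) cell, row-major (assumed)."""
    return divmod(pid, n_d)

N_q, N_d = 4, 8
# One entry per program: the output cell that program owns.
grid = [program_to_pair(pid, N_d) for pid in range(N_q * N_d)]
```

Every cell appears exactly once, so no two programs ever write the same score.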

3. Inside one program

Fix a single program at $(i, j)$. It has the $i$-th query and the $j$-th document available in HBM. The most direct implementation of its scalar output looks like

$$\Sm \;=\; \Qm_{i,\,\cdot}\, \Dm_{j,\,\cdot}^{\top} \;\in\; \mathbb{R}^{L_q \times L_d}, \qquad \operatorname{score}[i,j] \;=\; \sum_{s} \max_{t} \Sm_{s, t}.$$

Building $\Sm$ even for a single pair wastes most of the values: every row of $\Sm$ contributes only one number to the final score, the row maximum. The kernel never builds it. Instead the program walks the output in tiles. Pick block sizes $B_q$ and $B_d$ (usually 32 or 64). The program loops over $\lceil L_q / B_q \rceil$ query tiles and, inside each, over $\lceil L_d / B_d \rceil$ document tiles. The largest thing that ever exists is one $B_q \times B_d$ slice of $\Sm$, held in registers for the duration of a single tile-matmul.

[Figure: one inner-loop iteration, with all of $\Sm$ living in registers. Q tile ($B_q \times d$) · D tile ($d \times B_d$) = S tile ($B_q \times B_d$), followed by a row max giving tile_max (length $B_q$), merged by max(m, ·) into the running max m (length $B_q$, fp32). m persists across inner-loop iterations; S and tile_max are recomputed each step. Legend: Q tile, loaded once per outer iteration; D tile, loaded once per inner iteration; S tile, registers only; running max $m$, fp32 registers.]

Each inner-loop iteration loads a fresh $\Dm$ tile, multiplies it against the resident $\Qm$ tile on the tensor cores, takes the row-wise max of the resulting $\Sm$ tile, and merges it into the running row-max vector $\Mm$ (one fp32 value per query row of the tile). After the inner loop finishes, $\Mm$ holds the true per-row maximum over the full $L_d$ axis, and the program adds $\sum \Mm$ to its scalar accumulator. After the outer loop finishes, that accumulator is written to HBM. One pair, one scalar, one store.

This is the entire optimisation. A naive implementation writes the full $\Sm$ tensor to HBM and reads it back for the max, then sends the $[N_q, N_d, L_q]$ row-max tensor on a second round-trip through HBM for the sum. The fused program reads $\Qm$ and $\Dm$ once, keeps every $\Sm$ slice on-chip, and writes one fp32 scalar per pair. The next section makes the cost difference concrete.

4. Why the naive form is bandwidth-bound

The textbook implementation of MaxSim, the one inside PyLate's reference code, separates the two reductions:

import torch

# Q: [Nq, Lq, d], D: [Nd, Ld, d]
S = torch.einsum("nsd,mtd->nmst", Q, D)  # [Nq, Nd, Lq, Ld]  ← lives in HBM
M = S.max(dim=-1).values                 # [Nq, Nd, Lq]      ← max over t
scores = M.sum(dim=-1)                   # [Nq, Nd]          ← sum over s

The full similarity tensor $\Sm$ is written to HBM and read back to take the row-wise max; the $[N_q, N_d, L_q]$ max tensor then makes its own round-trip through HBM for the sum. The memory footprint of $\Sm$ is

$$\text{bytes}(\Sm) \;=\; N_q \cdot N_d \cdot L_q \cdot L_d \cdot \text{sizeof(dtype)}.$$

Scaling is quadratic in the sequence length. At ColPali shapes ($N_q = N_d = 64$, $L_q = L_d = 1024$, fp32) the tensor is $16$ GB; in fp16 it is $8$ GB. Training builds three or four such intermediates per step. The chart shows the same number plotted as a function of the sequence length:

[Figure: peak HBM for one $\Sm$ (axis 0 to 24 GB) against sequence length $L$ ($L_q = L_d$, 256 to 2048), with the 80 GB H100 limit marked. Two series: naive (materialises $\Sm$ in HBM, fp16, $N_q = N_d = 64$) and fused (streams tiles; only the $[N_q, N_d]$ score matrix).]
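
The footprint formula is easy to evaluate by hand; the short sketch below does the arithmetic for the shapes quoted above. The helper name is hypothetical, and sizes are reported in binary gibibytes ($2^{30}$ bytes), which is the unit in which the 16 and 8 figures come out exact.

```python
GIB = 1024 ** 3

def similarity_tensor_bytes(n_q, n_d, l_q, l_d, bytes_per_element):
    """Bytes needed to hold the full [Nq, Nd, Lq, Ld] similarity tensor."""
    return n_q * n_d * l_q * l_d * bytes_per_element

fp32_gib = similarity_tensor_bytes(64, 64, 1024, 1024, 4) / GIB   # 16.0
fp16_gib = similarity_tensor_bytes(64, 64, 1024, 1024, 2) / GIB   # 8.0

# What the fused kernel writes instead: one fp32 scalar per pair.
score_matrix_kib = 64 * 64 * 4 / 1024                             # 16.0

# Quadratic growth in sequence length (fp16, Nq = Nd = 64).
fp16_by_length = {
    L: similarity_tensor_bytes(64, 64, L, L, 2) / GIB
    for L in (256, 512, 1024, 2048)
}
```

Doubling the sequence length quadruples the footprint, while the score matrix stays at 16 KiB regardless of $L$.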

Beyond the memory pressure, the naive version is bandwidth-bound: each element of $\Sm$ is written to HBM once and read back once, in exchange for only $2d$ FLOPs of matmul per element. The arithmetic intensity is far below the H100 ridge point, so the SMs spend most of their time waiting on memory. Appendix B works the FLOP-per-byte numbers out and places both implementations on the roofline.

5. Fused forward: tile, stream, accumulate

Spelling out the loop nest sketched in section 3: one Triton program per $(q\text{-batch}, d\text{-batch})$ pair, document tiles of $\text{BLOCK\_D}$ rows nested inside query tiles of $\text{BLOCK\_Q}$ rows. The full similarity tensor $\Sm$ never exists in HBM. Only one $\text{BLOCK\_Q} \times \text{BLOCK\_D}$ slice exists at a time, and it lives on-chip in registers for the duration of one tile-matmul.

# one program per (q_batch, d_batch)
score ← 0
for q_start in range(0, Lq, BLOCK_Q):
    Q_tile ← Q[q_batch, q_start : q_start + BLOCK_Q]      # SRAM
    m     ← −∞                                            # [BLOCK_Q] in registers
    for d_start in range(0, Ld, BLOCK_D):
        D_tile ← D[d_batch, d_start : d_start + BLOCK_D]  # SRAM
        S_tile ← Q_tile @ D_tileᵀ                         # tensor cores, fp32 acc
        S_tile ← mask(S_tile, d_active, −∞)               # fused-in masking
        m      ← maximum(m, rowmax(S_tile))               # online max
    score ← score + sum(m)                                # contribution from this Q tile
scores[q_batch, d_batch] ← score                           # only this scalar is written
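
The loop nest above can be emulated on the CPU to check that tiling changes nothing about the result. This is a NumPy sketch under stated assumptions, not the Triton kernel: `maxsim_tiled` is a hypothetical name, the two outermost Python loops stand in for the launch grid, and the block sizes are chosen small so that ragged final tiles are exercised.

```python
import numpy as np

def maxsim_tiled(Q, D, block_q=2, block_d=4, d_mask=None):
    """Streams block_q x block_d tiles of S; the full S is never built."""
    n_q, l_q, _ = Q.shape
    n_d, l_d, _ = D.shape
    scores = np.zeros((n_q, n_d))
    for i in range(n_q):                      # on the GPU these two loops are the grid
        for j in range(n_d):
            score = 0.0
            for q0 in range(0, l_q, block_q):
                q_tile = Q[i, q0:q0 + block_q]
                m = np.full(q_tile.shape[0], -np.inf)     # running row max
                for d0 in range(0, l_d, block_d):
                    d_tile = D[j, d0:d0 + block_d]
                    s_tile = q_tile @ d_tile.T            # the only slice of S alive
                    if d_mask is not None:
                        active = d_mask[j, d0:d0 + block_d]
                        s_tile = np.where(active[None, :], s_tile, -np.inf)
                    m = np.maximum(m, s_tile.max(axis=1)) # online max
                score += m.sum()                          # this Q tile's contribution
            scores[i, j] = score                          # one scalar per pair
    return scores

rng = np.random.default_rng(0)
Q = rng.standard_normal((3, 5, 8))            # Lq=5 is not a multiple of block_q
D = rng.standard_normal((4, 10, 8))           # Ld=10 is not a multiple of block_d
tiled = maxsim_tiled(Q, D)
reference = np.einsum("nsd,mtd->nmst", Q, D).max(axis=-1).sum(axis=-1)
```

The two results agree to floating-point tolerance; the tiled version simply never holds more than one $2 \times 4$ slice of $\Sm$ at a time.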

The structure is the same outer-product tiling that FlashAttention uses, with one simplification: where FlashAttention has to maintain a running max and a running sum of exponentials for softmax, MaxSim only has a running max. The recurrence is exact and there is nothing to rescale.

Per program, SRAM holds only the operand tiles, $(\text{BLOCK\_Q} + \text{BLOCK\_D}) \cdot d$ elements in low precision. Triton double-buffers them so the next $\Dm$ tile loads while the current GEMM runs. The score tile $\Sm$ and the running max $\Mm$ live in fp32 registers, never in SRAM: tl.dot returns its accumulator straight into the register file, and the row-reduction happens there. At typical block sizes ($\text{BLOCK\_Q} = \text{BLOCK\_D} = 64$, $d = 128$, 2-byte elements), one copy of the operand tiles is $(64 + 64) \cdot 128 \cdot 2$ B $= 32$ KiB per program, comfortably below the H100's 228 KiB per SM. That leaves room for multiple concurrent programs and for the autotuner to trade block size against occupancy.
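
The shared-memory budget can be worked out in a few lines. The sketch below assumes 2-byte operands and models double-buffering as one extra copy of the $\Dm$ tile only; how many stages Triton actually allocates depends on the tuned configuration, so the `d_stages` parameter and the helper name are illustrative assumptions.

```python
KIB = 1024

def operand_smem_bytes(block_q, block_d, d, bytes_per_element=2, d_stages=1):
    """Shared memory for one resident Q tile plus `d_stages` buffered D tiles."""
    return (block_q + d_stages * block_d) * d * bytes_per_element

single = operand_smem_bytes(64, 64, 128)                # one copy of each tile
double = operand_smem_bytes(64, 64, 128, d_stages=2)    # D tile double-buffered

H100_SMEM_PER_SM = 228 * KIB
# Upper bound on co-resident programs from shared memory alone;
# register pressure can lower the real occupancy further.
programs_per_sm = H100_SMEM_PER_SM // double
```

Here `single` is 32 KiB and `double` is 48 KiB, so shared memory by itself would allow up to four such programs per SM.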

Long-query chunking

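For ColPali-scale queries ($L_q \approx 1024$ visual patches), the outer loop in the pseudo-code above serialises $\lceil L_q / \text{BLOCK\_Q} \rceil$ query-tile iterations inside one program ($\sim\!8$ at $\text{BLOCK\_Q} = 128$, 16 at $\text{BLOCK\_Q} = 64$), and when $N_q \cdot N_d$ is small the grid has too few programs to keep every SM busy. Because MaxSim is a sum over query tokens of a per-token max, it decomposes exactly: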

$$\text{score}[i, j] = \sum_{c=0}^{n_c - 1} \sum_{s \in \text{chunk}_c} \max_t \langle \Qm_{i, s},\, \Dm_{j, t} \rangle.$$

maxsim() exploits this whenever $L_q > 512$ by reshaping $\Qm \in \mathbb{R}^{N_q \times L_q \times d}$ into $\Qm_c \in \mathbb{R}^{N_q n_c \times 128 \times d}$ (128 tokens per chunk, $n_c = \lceil L_q / 128 \rceil$), running the kernel over the larger batch, and summing the chunk scores back per original query. The reshape and sum are plain tensor ops, so autograd flows through them unchanged and no custom backward is needed.

Two gains at once: more programs on the grid (better SM utilisation) and the kernel always sees $L_q = 128$, pinning the autotune cache to a small constant rather than one entry per distinct query length.

Splitting only pays off once the per-program serial loop dominates the cost of launching the extra programs. Below that crossover the serial query loop is already short and the grid already full, so chunking adds launch overhead and reshape-and-sum work without exposing any new parallelism. The split therefore fires only when $L_q > 512$; short queries (ColBERT $L_q \leq 32$, long-document $L_q \leq 512$) take the un-chunked path, which is the faster one for them.
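
The exactness of the decomposition is easy to confirm numerically. The following NumPy sketch chunks the query axis, pads the last chunk, masks the padding out of the sum, and compares against the un-chunked score. The names `maxsim_dense` and `maxsim_chunked` are hypothetical, the padding-plus-mask handling is an assumption about how a ragged last chunk could be treated, and a chunk size of 4 replaces the 128 used in the text to keep the example small.

```python
import numpy as np

def maxsim_dense(Q, D, q_mask):
    """Reference MaxSim with a boolean query mask of shape [Nq, Lq]."""
    S = np.einsum("nsd,mtd->nmst", Q, D)
    M = np.where(q_mask[:, None, :], S.max(axis=-1), 0.0)
    return M.sum(axis=-1)

def maxsim_chunked(Q, D, chunk):
    n_q, l_q, d = Q.shape
    n_c = -(-l_q // chunk)                                  # ceil(Lq / chunk)
    pad = n_c * chunk - l_q
    Q_pad = np.pad(Q, ((0, 0), (0, pad), (0, 0)))           # zero-pad the token axis
    valid = np.arange(n_c * chunk) < l_q                    # padding tokens are invalid
    Q_c = Q_pad.reshape(n_q * n_c, chunk, d)                # chunks become extra queries
    mask_c = np.broadcast_to(valid, (n_q, n_c * chunk)).reshape(n_q * n_c, chunk)
    chunk_scores = maxsim_dense(Q_c, D, mask_c)             # [Nq * n_c, Nd]
    return chunk_scores.reshape(n_q, n_c, -1).sum(axis=1)   # sum chunks per query

rng = np.random.default_rng(0)
Q = rng.standard_normal((2, 10, 8))                         # Lq=10 needs padding at chunk=4
D = rng.standard_normal((3, 7, 8))
chunked = maxsim_chunked(Q, D, chunk=4)
unchunked = maxsim_dense(Q, D, np.ones((2, 10), dtype=bool))
```

Because the reshape and the final sum are ordinary tensor operations, the same structure carries over to an autograd framework without a custom backward.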

Step through the algorithm

[Interactive figure: a toy execution with $L_q = 4$, $L_d = 12$, $d = 4$, $\text{BLOCK\_Q} = 4$ (one outer iteration) and $\text{BLOCK\_D} = 4$ (three inner iterations). It tracks the state of the current tile (q_start, d_start, running max m, partial sum Σm, final score) and the HBM traffic of one program (reads, writes, fused total, naive baseline, ratio). Legend: Q (active), D (active), S tile (ephemeral, on-chip only), running max m, consumed.]

6. The online-max recurrence

To avoid storing the score tensor, the inner loop reduces each $\Sm$ tile to its per-row maximum as soon as the tile is computed and folds that into a running maximum $\Mm$. After consuming the $k$-th tile, the running max obeys

$$\Mm^{(k)} \;=\; \max\!\bigl(\Mm^{(k-1)},\;\; \operatorname{rowmax}(\Sm^{(k)})\bigr),$$

with $\Mm^{(0)} = -\infty$. After all tiles are consumed, $\Mm$ is the true per-row maximum and $\sum_s \Mm_s$ is the score contribution for the current $\text{BLOCK\_Q}$. Both quantities are computed in fp32 in registers; the GEMM uses bf16/fp16 operands with fp32 accumulation, matching the tensor-core native path.

Why this is simpler than FlashAttention, and exact. Online softmax has to track the running max $\Mm$ and the running sum of exponentials $\ell$, and rescale $\ell$ and the output accumulator whenever the running max changes: the rescalers are the $\exp(\Mm^{(k-1)} - \Mm^{(k)})$ factors in FA-1. Online max has no such rescaler: $\max$ is associative and commutative and returns one of its inputs without rounding, so the update is just an elementwise max and the result is identical to an offline max over the same $\Sm$ values.

Worked example

Three doc-tiles, one query row. The argmax index is recorded for backward.

| Tile | $\Sm$ row (values seen this tile) | tile max | running max after tile | argmax (global $t$) |
|---|---|---|---|---|
| 0 | [0.42, 0.11, 0.30, 0.18] | 0.42 | 0.42 | 0 |
| 1 | [0.20, 0.55, 0.05, 0.31] | 0.55 | 0.55 | 5 |
| 2 | [0.49, 0.40, 0.50, 0.22] | 0.50 | 0.55 | 5 |
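
The worked example can be replayed in plain Python. This is a sketch of the recurrence for a single query row, with a hypothetical `online_max` helper; it keeps the earliest index on ties, which mirrors the lowest-index tie-break described for the forward pass.

```python
def online_max(tiles):
    """Stream the tiles of one S row, tracking the running max and its global index."""
    running_max, argmax, offset = float("-inf"), -1, 0
    history = []
    for tile in tiles:
        local = max(range(len(tile)), key=tile.__getitem__)  # first index of the tile max
        if tile[local] > running_max:                        # strict >: earlier index wins ties
            running_max, argmax = tile[local], offset + local
        offset += len(tile)
        history.append((running_max, argmax))
    return history

tiles = [
    [0.42, 0.11, 0.30, 0.18],
    [0.20, 0.55, 0.05, 0.31],
    [0.49, 0.40, 0.50, 0.22],
]
history = online_max(tiles)   # [(0.42, 0), (0.55, 5), (0.55, 5)]

# The offline answer over the concatenated row is the same.
flat = [value for tile in tiles for value in tile]
offline = (max(flat), flat.index(max(flat)))
```

The final streamed state matches the offline max and argmax exactly, with no rescaling step in between.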

Masked positions are set to $-\infty$ before the reduction, so they can never win the max or the argmax, even when every valid score is negative. This is stricter than PyLate's reference, which post-multiplies by a $0/1$ mask (a masked position then scores $0$ and beats any negative valid score), and matches the masking discipline used by flash-maxsim and FlashAttention.
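
A three-element example shows why the choice of mask value matters. The numbers are made up for illustration: two valid document tokens with negative scores and one padding position.

```python
NEG_INF = float("-inf")

scores = [-0.30, -0.10, 0.80]     # the last position is padding
valid = [True, True, False]

# 0/1 multiplicative mask: the padding position becomes 0.0 and wins the max.
zero_masked = [s * v for s, v in zip(scores, valid)]
zero_winner = zero_masked.index(max(zero_masked))

# -inf mask: the padding position can never win.
inf_masked = [s if v else NEG_INF for s, v in zip(scores, valid)]
inf_winner = inf_masked.index(max(inf_masked))
```

With the multiplicative mask the winner is the padding slot at index 2; with the $-\infty$ mask it is the valid token at index 1.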

7. Backward: argmax-only gradients

$\max$ is sub-differentiable: only the argmax position carries a gradient, so the backward pass is simpler than for softmax. With $g \in \mathbb{R}^{N_q \times N_d}$ the upstream gradient on the score and $m^q_{i,s} \in \{0, 1\}$ the query-token mask, a short chain-rule walk (Appendix C) gives:

$$\nabla_{\Qm_{i, s}} \;=\; m^q_{i,s} \cdot \sum_{j} g_{i,j} \cdot \Dm_{j,\; \operatorname{argmax}_t \langle \Qm_{i,s}, \Dm_{j,t}\rangle},$$

$$\nabla_{\Dm_{j, t}} \;=\; \sum_{(i, s)\, :\, \operatorname{argmax}[i, j, s] \,=\, t} m^q_{i,s} \cdot g_{i, j} \cdot \Qm_{i, s}.$$

To avoid recomputing the forward, the forward kernel optionally writes an $[N_q \cdot N_d, L_q]$ int32 buffer of argmax indices (4 MB for a typical training batch). $\nabla_{\Qm}$ is then embarrassingly parallel: one program per $(i, s)$ gathers $\Dm_{j, \text{argmax}}$ across $j$ and produces a single output row.
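
The gather for $\nabla_{\Qm}$ can be sketched in NumPy as follows. The function name is hypothetical, and the argmax buffer is kept as a three-dimensional $[N_q, N_d, L_q]$ array for readability rather than the flattened $[N_q \cdot N_d, L_q]$ layout described above.

```python
import numpy as np

def grad_q_from_argmax(D, argmax, g, q_mask):
    """grad_Q[i, s] = m_q[i, s] * sum_j g[i, j] * D[j, argmax[i, j, s]]."""
    n_d = D.shape[0]
    j_idx = np.arange(n_d)[None, :, None]            # broadcasts against argmax
    winners = D[j_idx, argmax]                       # [Nq, Nd, Lq, d]: one D row per (i, j, s)
    grad = np.einsum("ij,ijsk->isk", g, winners)     # weight by upstream grad, sum over j
    return grad * q_mask[:, :, None]                 # masked query tokens get zero gradient

rng = np.random.default_rng(0)
Q = rng.standard_normal((2, 3, 4))
D = rng.standard_normal((5, 6, 4))
g = rng.standard_normal((2, 5))
q_mask = np.array([[1.0, 1.0, 0.0], [1.0, 1.0, 1.0]])

argmax = np.einsum("nsd,mtd->nmst", Q, D).argmax(axis=-1)   # saved by the forward
grad_Q = grad_q_from_argmax(D, argmax, g, q_mask)
```

Each output row reads exactly one row of $\Dm$ per document, so no two writers ever touch the same location.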

Why $\nabla_{\Dm}$ is the hard one

Picture the scatter for a fixed doc-batch index $j$. Each query-token pair $(i, s)$ contributes one row into $\nabla_{\Dm}[j, \cdot]$ at row $\operatorname{argmax}[i, j, s]$. Whether two pairs collide is a runtime property of the argmax. The kernel cannot know it statically.

[Figure: six source pairs, $(i, s)$ for $i \in \{0, 1\}$ and $s \in \{0, 1, 2\}$, scattering into rows of grad_D[j, ·]. grad_D[j, 5] has 4 writers, grad_D[j, 12] and grad_D[j, 23] have 1 writer each, and grad_D[j, 18] and grad_D[j, L_d−1] have 0 writers.]
Six source pairs scattering into the grad_D[j, ·] rows for one doc-batch. The bucket sizes are data-dependent: row 5 is a hot spot (4 source pairs land on it); most rows have 0 or 1. Both the per-row fan-in and the empty-row pattern matter for which kernel wins.

The library offers two reduction strategies to handle this contention; auto (the default) chooses between them:

| Method | Bitwise reproducible | When it is picked |
|---|---|---|
| unified | no (atomic, ≤1e-6 rel) | Default for cross-products. Single-pass fused $\nabla_{\Qm} + \nabla_{\Dm}$. Hoists $\Qm_{i, s}$ out of the doc-batch loop, roughly halving HBM read traffic; $\nabla_{\Dm}$ writes use fp32 atomic add. |
| lowmem | yes | Picked for gradient-heavy shapes: KD / hard-negative layouts and high-contention squares ($N_q, N_d \geq 256$, $L_q \leq 64$). Destination-owned: accumulates in fp32 registers and writes $\nabla_{\Qm} / \nabla_{\Dm}$ directly in the input dtype, so no full-size fp32 buffer, no fp32→bf16 transient, no atomics. Roughly halves backward peak memory and is the deterministic path. |

How lowmem avoids atomics. The unified path lives with the scatter pattern at runtime and uses fp32 atomic adds into a full-size buffer. lowmem inverts it: each Triton program owns one $(j, t)$ output row and reduces the contributing $(i, s)$ pairs from the saved argmax in fp32 registers via a one-hot matmul, then writes once in the input dtype. The result has no atomics, is bitwise-stable across launches, and needs no fp32→bf16 transient. That is why auto routes the high-contention $N_q, N_d \geq 256 \wedge L_q \leq 64$ regime (and every KD / hard-negative layout) there.

All paths use Triton's stable argmax (lowest-index tie-break) on the forward, so only the $\nabla_{\Dm}$ reduction order distinguishes them numerically.
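
The two reduction strategies can be contrasted in NumPy. This is a sketch of the idea only: `np.add.at` stands in for the atomic scatter-add of the unified path, the one-hot contraction stands in for the destination-owned reduction of lowmem, and both function names are hypothetical. Query masking is omitted (all tokens are assumed valid), and the one-hot tensor is built in full here purely for clarity, which a real kernel would not do.

```python
import numpy as np

def grad_d_scatter(Q, argmax, g, l_d):
    """Source-owned: every (i, j, s) adds g[i, j] * Q[i, s] into its winner row."""
    n_q, n_d, l_q = argmax.shape
    grad = np.zeros((n_d, l_d, Q.shape[-1]))
    contrib = g[:, :, None, None] * Q[:, None, :, :]                 # [Nq, Nd, Lq, d]
    j_idx = np.broadcast_to(np.arange(n_d)[None, :, None], argmax.shape)
    np.add.at(grad, (j_idx, argmax), contrib)                        # accumulates on collisions
    return grad

def grad_d_owned(Q, argmax, g, l_d):
    """Destination-owned: each (j, t) row reduces its own bucket via a one-hot contraction."""
    one_hot = (argmax[..., None] == np.arange(l_d)).astype(Q.dtype)  # [Nq, Nd, Lq, Ld]
    return np.einsum("ij,ijst,isk->jtk", g, one_hot, Q)

rng = np.random.default_rng(0)
Q = rng.standard_normal((3, 6, 4))
D = rng.standard_normal((2, 5, 4))
g = rng.standard_normal((3, 2))
argmax = np.einsum("nsd,mtd->nmst", Q, D).argmax(axis=-1)            # [Nq, Nd, Lq]

scattered = grad_d_scatter(Q, argmax, g, D.shape[1])
owned = grad_d_owned(Q, argmax, g, D.shape[1])
```

Both functions compute the same sums; on a GPU the difference is that the scatter's accumulation order depends on scheduling, while the destination-owned reduction fixes the order per output row.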

Backward launch-param autotuning

The backward kernels have no block-tiling dimensions to sweep. Each is one Triton program per output row streaming a single embedding vector, so they are autotuned over num_warps and num_stages only, keyed like the forward so the cache holds one entry per training regime.

8. Variants in the library

The streaming-and-never-storing structure is the entire optimisation. The rest of the library applies the same idea to a few adjacent workloads:

  • Variable-length / packed input (maxsim_varlen). Real corpora have ragged lengths; padding to $L_d^{\max}$ wastes roughly half the FLOPs. The packed kernel consumes $[\text{total\_tokens}, d]$ tensors with FlashAttention-style cu_seqlens offsets and runs the same tiling.
  • Fused $D$-side projection (maxsim_from_hidden). For corpora stored as ModernBERT hidden states, the kernel folds $\texttt{Linear} \rightarrow \texttt{L2-normalize} \rightarrow \texttt{MaxSim}$ into a single pass. The projected embedding tensor is never written to HBM. The training backward gathers only the winning positions and recomputes the projection locally.
  • PLAID / ColBERTv2 (maxsim_residual, maxsim_residual_varlen). The doc embeddings live on disk as $(\text{centroid index}, \text{quantised residual})$. The kernel decompresses on the fly in SRAM, optionally L2-normalises, and runs MaxSim, all in one program with no dense decompressed tensor ever materialised.
  • FP8 MaxSim (maxsim_inference_fp8). On Hopper, the GEMM uses the FP8 tensor cores at twice the bf16 throughput; the reduction stays in fp32 to preserve range.

Different inputs, same recipe: keep large intermediates on-chip, run the reductions in registers, and ship only the final scores to HBM.
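
To make the packed layout of the variable-length variant concrete, the sketch below scores one query against ragged documents stored as a single $[\text{total\_tokens}, d]$ array with cumulative offsets. It illustrates the data layout only: `maxsim_packed_docs` is a hypothetical NumPy reference, not the signature of maxsim_varlen, it materialises the similarity matrix that the real kernel streams, and it assumes every document has at least one token.

```python
import numpy as np

def maxsim_packed_docs(q, D_packed, cu_seqlens):
    """Score one query [Lq, d] against documents packed as [total_tokens, d]."""
    sims = q @ D_packed.T                                   # [Lq, total_tokens]
    # reduceat takes the max over each [cu_seqlens[k], cu_seqlens[k+1]) segment;
    # this is only correct when no document is empty.
    per_doc_max = np.maximum.reduceat(sims, cu_seqlens[:-1], axis=1)   # [Lq, n_docs]
    return per_doc_max.sum(axis=0)                          # one score per document

rng = np.random.default_rng(0)
lengths = [3, 5, 2]                                         # ragged document lengths
cu_seqlens = np.concatenate([[0], np.cumsum(lengths)])      # [0, 3, 8, 10]
D_packed = rng.standard_normal((int(cu_seqlens[-1]), 8))
q = rng.standard_normal((4, 8))
packed_scores = maxsim_packed_docs(q, D_packed, cu_seqlens)
```

No padding tokens are stored or multiplied: the ten packed rows are all real tokens, and the offsets alone say where each document starts and ends.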

A. GPU memory hierarchy

Modern GPUs are dramatically faster at arithmetic than at moving data. An H100 sustains roughly $990$ TFLOP/s of bf16 tensor-core throughput against $3.35$ TB/s of HBM bandwidth, a ratio of about $295$ FLOPs per byte. Any kernel that performs fewer FLOPs than that per byte of HBM traffic is bandwidth-bound, and the SMs sit idle waiting for memory. The trick to writing a fast kernel is therefore not faster arithmetic, but fewer round-trips to HBM. The hierarchy in play:

| Level | Size | Bandwidth | Where it lives |
|---|---|---|---|
| Registers | 256 KB / SM | ~20 TB/s | Inside each thread |
| SRAM (shared memory) | 228 KB / SM (H100) | ~19 TB/s | On-chip, per SM |
| L2 cache | 50 MB (H100) | ~5 TB/s | On-chip, chip-wide |
| HBM | 80 GB (H100) | ~3.3 TB/s | Off-chip, on the package |

SRAM is roughly six times faster than HBM, and registers are faster still. A fused kernel chains several logical operations back-to-back inside SRAM and registers, writing only the final result to HBM. That is the optimisation strategy used throughout this library.

B. Arithmetic intensity

Whether a kernel is bound by compute or by memory is captured by a single number, the arithmetic intensity, $\mathrm{AI} = \text{FLOPs} / \text{HBM bytes moved}$. Plot achievable throughput against $\mathrm{AI}$ and you get the roofline: a bandwidth slope on the left, a compute plateau on the right, meeting at the ridge point $\mathrm{AI}^{*} = \text{peak compute} / \text{peak bandwidth}$. For an H100 in bf16, $\mathrm{AI}^{*} \approx 295$ FLOPs/byte. Below that you are memory-bound, above it compute-bound.

Scoring one $(\Qm, \Dm)$ pair is $2 L_q L_d \, d$ FLOPs of matmul plus a few cheap reductions. The two implementations differ only in what they push through HBM. The naive form materialises the full $L_q \times L_d$ score matrix $\Sm$, so its traffic is dominated by that surface (one write plus one read, at 2 bytes per fp16 element):

$$\mathrm{AI}_{\text{naive}} \;\approx\; \frac{2 L_q L_d \, d}{4 L_q L_d} \;=\; \frac{d}{2}.$$

The fused form never writes $\Sm$; the only HBM traffic is the operand reads $\Qm$ and $\Dm$:

$$\mathrm{AI}_{\text{fused}} \;\approx\; \frac{2 L_q L_d \, d}{2 (L_q + L_d) \, d} \;=\; \frac{L_q L_d}{L_q + L_d}.$$

At ColPali shapes ($L_q = L_d = 1024$, $d = 128$), the naive form sits at about 64 FLOPs/byte, four to five times below the ridge, so the tensor cores idle while HBM works flat out. The fused form sits at about 512 FLOPs/byte, comfortably above the ridge, where the kernel is finally limited by compute. The same workload moves between regimes purely by changing what crosses HBM:

[Figure: roofline on a log-scale arithmetic-intensity axis (FLOPs per HBM byte, 1 to 1000). The ridge at AI ≈ 295 separates the memory-bound region from the compute-bound region; naive sits at ≈ 64 FLOPs/B, fused at ≈ 512 FLOPs/B.]

Fusing is not faster because the arithmetic is cheaper. It is the same matmul. It is faster because the matmul never has to wait for memory.
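
The two intensity formulas and the ridge point can be evaluated directly. The sketch uses the approximate H100 figures quoted in this walkthrough and counts only the dominant HBM traffic, exactly as the formulas above do; the helper names are hypothetical.

```python
def ai_naive(d, bytes_per_element=2):
    """FLOPs per HBM byte when S is written once and read once."""
    flops_per_s_element = 2 * d                    # d multiply-adds per inner product
    bytes_per_s_element = 2 * bytes_per_element    # one write + one read
    return flops_per_s_element / bytes_per_s_element

def ai_fused(l_q, l_d, d, bytes_per_element=2):
    """FLOPs per HBM byte when only the Q and D operands cross HBM."""
    flops = 2 * l_q * l_d * d
    hbm_bytes = (l_q + l_d) * d * bytes_per_element
    return flops / hbm_bytes

ridge = 990e12 / 3.35e12                           # peak bf16 FLOP/s over HBM bytes/s

naive = ai_naive(128)                              # 64.0
fused = ai_fused(1024, 1024, 128)                  # 512.0
```

The naive value depends only on $d$, while the fused value grows with sequence length, which is why longer sequences push the fused kernel further into the compute-bound regime.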

C. Deriving the backward gradients

Section 7 quotes the closed forms for $\nabla_{\Qm}$ and $\nabla_{\Dm}$. They drop out of the forward formula in three short steps: resolve the $\max$, apply the chain rule, then read off the two cases.

Resolving the max

Once the forward has run, define the winner for each query token:

$$t^{\star}(i, j, s) \;=\; \arg\max_{t} \, \langle \Qm_{i, s},\, \Dm_{j, t}\rangle.$$

The score then collapses into a plain sum of inner products at the winning slots — the $\max$ has been resolved into a concrete index:

$$\text{score}(i, j) \;=\; \sum_{s} \langle \Qm_{i, s},\, \Dm_{j,\, t^{\star}(i, j, s)}\rangle.$$

$\max$ is non-smooth, but at any non-tied point its subgradient equals the gradient of whichever term achieves the maximum. So locally, for a perturbation small enough that no argmax flips, $\text{score}(i, j)$ is just a sum of bilinear terms in $\Qm$ and $\Dm$, and $t^{\star}$ is a constant we read out of the saved buffer.

Chain rule

Let $g_{i, j} = \partial L / \partial \text{score}(i, j)$ be the upstream gradient. For any parameter $X$:

$$\frac{\partial L}{\partial X} \;=\; \sum_{i, j} g_{i, j} \cdot \frac{\partial\, \text{score}(i, j)}{\partial X}.$$

Case 1: $\partial / \partial \Qm_{i, s, k}$

$\text{score}(i', j)$ depends on the row $\Qm_{i, s}$ only when the query indices match ($i' = i$), and then only the $s$-th term in the sum contains the component $\Qm_{i, s, k}$. That term is $\langle \Qm_{i, s},\, \Dm_{j,\, t^{\star}(i, j, s)}\rangle$, so the partial is the matching entry of $\Dm$:

$$\frac{\partial\, \text{score}(i, j)}{\partial \Qm_{i, s, k}} \;=\; \Dm_{j,\, t^{\star}(i, j, s),\, k}.$$

Weighting by $g_{i, j}$ and summing over $j$, the only free index left in the chain rule, gives $\nabla_{\Qm_{i, s}}$. Each $j$ picks exactly one row of $\Dm$ out of $L_d$, a pure gather with no collisions.

Case 2: $\partial / \partial \Dm_{j, t, k}$

$\text{score}(i, j_0)$ depends on $\Dm_{j_0, \cdot, \cdot}$ only when the doc-batch indices match. Within $\text{score}(i, j_0) = \sum_s \langle \Qm_{i, s},\, \Dm_{j_0,\, t^{\star}(i, j_0, s)}\rangle$, the entry $\Dm_{j_0, t, k}$ appears in the $s$-th term iff $t^{\star}(i, j_0, s) = t$. For each such $s$ the partial is $\Qm_{i, s, k}$, and multiple $s$'s may collide on the same $t$ — they all add:

$$\frac{\partial\, \text{score}(i, j)}{\partial \Dm_{j, t, k}} \;=\; \sum_{s\,:\, t^{\star}(i, j, s) \,=\, t} \Qm_{i, s, k}.$$

Chaining over $i$ with $j$ fixed folds the two sums into one set:

$$\frac{\partial L}{\partial \Dm_{j, t, k}} \;=\; \sum_{(i, s)\,:\, t^{\star}(i, j, s) \,=\, t} g_{i, j} \cdot \Qm_{i, s, k}.$$

That set $\{(i, s) : t^{\star}(i, j, s) = t\}$ is exactly the bucket the lowmem kernel reduces per output row, and the same contention pattern the unified path resolves with atomic_add.

The mask falls out for free

The forward gates masked query tokens out of the outer sum, so masked positions contribute nothing to the score and their partials are zero on both sides. The $m^q_{i, s}$ factor in the closed-form expressions is just that gating carried through.
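
The closed forms can be checked against central finite differences on a small random problem. This is a NumPy sketch with hypothetical function names; it takes the loss to be $L = \sum_{i,j} g_{i,j} \cdot \text{score}(i, j)$ for a fixed random $g$, and relies on the perturbation being small enough that no argmax flips.

```python
import numpy as np

def loss(Q, D, g, q_mask):
    S = np.einsum("nsd,mtd->nmst", Q, D)
    M = S.max(axis=-1) * q_mask[:, None, :]          # masked query tokens drop out of the sum
    return float((g * M.sum(axis=-1)).sum())

def closed_form_grads(Q, D, g, q_mask):
    S = np.einsum("nsd,mtd->nmst", Q, D)
    t_star = S.argmax(axis=-1)                       # winner per (i, j, s)
    one_hot = (t_star[..., None] == np.arange(D.shape[1])).astype(Q.dtype)
    w = g[:, :, None] * q_mask[:, None, :]           # g[i, j] * m_q[i, s]
    grad_Q = np.einsum("ijs,ijst,jtk->isk", w, one_hot, D)
    grad_D = np.einsum("ijs,ijst,isk->jtk", w, one_hot, Q)
    return grad_Q, grad_D

def numerical_grad(f, X, eps=1e-6):
    """Central differences, perturbing X in place one entry at a time."""
    grad = np.zeros_like(X)
    for idx in np.ndindex(X.shape):
        old = X[idx]
        X[idx] = old + eps
        hi = f()
        X[idx] = old - eps
        lo = f()
        X[idx] = old
        grad[idx] = (hi - lo) / (2 * eps)
    return grad

rng = np.random.default_rng(0)
Q = rng.standard_normal((2, 3, 3))
D = rng.standard_normal((2, 4, 3))
g = rng.standard_normal((2, 2))
q_mask = np.array([[1.0, 1.0, 0.0], [1.0, 1.0, 1.0]])

grad_Q, grad_D = closed_form_grads(Q, D, g, q_mask)
num_Q = numerical_grad(lambda: loss(Q, D, g, q_mask), Q)
num_D = numerical_grad(lambda: loss(Q, D, g, q_mask), D)
```

The analytic and numerical gradients agree to within finite-difference error, and the masked query token receives a zero gradient on both sides.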