Scoring every query–document token pair without materialising the table
Late-interaction retrieval models compare every query token to every document token. A direct implementation builds a four-dimensional similarity tensor that overflows GPU memory at modern context lengths. The kernels in this library compute the same result without ever storing that tensor, by streaming tiles through on-chip memory and folding the reductions into the same pass.
1. The MaxSim score
MaxSim is the scoring function ColBERT-style late-interaction models use. The inputs are two batches of token-level embeddings:
- $\Qm \in \mathbb{R}^{N_q \times L_q \times d}$: $N_q$ queries, each with $L_q$ token vectors of dimension $d$.
- $\Dm \in \mathbb{R}^{N_d \times L_d \times d}$: $N_d$ documents, each with $L_d$ token vectors of the same dimension $d$.
For every (query $i$, document $j$) pair, the score is:
$$\operatorname{score}[i, j] \;=\; \sum_{s=1}^{L_q} \; \max_{t=1}^{L_d} \; \langle \Qm_{i, s}, \Dm_{j, t} \rangle.$$
For each query token, find the document token it matches best (highest inner product), then sum those best matches across all query tokens. The result is a score matrix in $\mathbb{R}^{N_q \times N_d}$. Optional boolean masks drop padding tokens from the sum and from the inner max.
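As a concrete reference for the formula above, here is a minimal NumPy sketch. The function name `maxsim_reference` and the `q_mask` / `d_mask` arguments are illustrative assumptions, not the library's API, and the sketch materialises the full similarity tensor, so it is only meant for small shapes.

```python
import numpy as np

def maxsim_reference(Q, D, q_mask=None, d_mask=None):
    """Naive MaxSim. Q: [Nq, Lq, d], D: [Nd, Ld, d]. Returns [Nq, Nd]."""
    # S[i, j, s, t] = <Q[i, s], D[j, t]>
    S = np.einsum("isd,jtd->ijst", Q, D)
    if d_mask is not None:
        # Padding document tokens can never win the inner max.
        S = np.where(d_mask[None, :, None, :], S, -np.inf)
    M = S.max(axis=-1)                     # [Nq, Nd, Lq]: best match per query token
    if q_mask is not None:
        # Padding query tokens are dropped from the outer sum.
        M = np.where(q_mask[:, None, :], M, 0.0)
    return M.sum(axis=-1)                  # [Nq, Nd]

rng = np.random.default_rng(0)
Q = rng.standard_normal((2, 3, 4))         # 2 queries, 3 tokens each
D = rng.standard_normal((5, 7, 4))         # 5 documents, 7 tokens each
scores = maxsim_reference(Q, D)
print(scores.shape)  # (2, 5)
```

The boolean masks follow the convention in the text: `True` marks a real token, `False` marks padding.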
2. Programs and the launch grid
A GPU kernel is a small function the device runs in many parallel instances. Each instance is a program in Triton's vocabulary, or a thread block in CUDA's. Writing a kernel comes down to two design choices: how many programs to launch, and what one program computes.
The MaxSim forward uses a grid of $N_q \cdot N_d$ programs, one per query–document pair. Program $(i, j)$ pulls the embeddings of query $i$ and document $j$ from HBM, accumulates a single fp32 scalar, and writes it to position $(i, j)$ of the score matrix. The programs are independent of one another, so the hardware schedules them onto whichever SMs are free and runs as many concurrently as the GPU's resources allow.
3. Inside one program
Fix a single program at $(i, j)$. It has the $i$-th query and the $j$-th document available in HBM. The most direct implementation of its scalar output looks like
$$\Sm \;=\; \Qm_{i,\,\cdot}\, \Dm_{j,\,\cdot}^{\top} \;\in\; \mathbb{R}^{L_q \times L_d}, \qquad \operatorname{score}[i,j] \;=\; \sum_{s} \max_{t} \Sm_{s, t}.$$
Building $\Sm$ even for a single pair wastes most of the values: every row of $\Sm$ contributes only one number to the final score, the row maximum. The kernel never builds it. Instead the program walks the output in tiles. Pick block sizes $B_q$ and $B_d$ (usually 32 or 64). The program loops over $\lceil L_q / B_q \rceil$ query tiles and, inside each, over $\lceil L_d / B_d \rceil$ document tiles. The largest thing that ever exists is one $B_q \times B_d$ slice of $\Sm$, held in registers for the duration of a single tile-matmul.
Each query tile carries a running row-maximum vector $\Mm \in \mathbb{R}^{B_q}$, initialised to $-\infty$. Each inner-loop iteration loads a fresh $\Dm$ tile, multiplies it against the resident $\Qm$ tile on the tensor cores, takes the row-wise max of the resulting $\Sm$ tile, and merges it into $\Mm$. After the inner loop finishes, $\Mm$ holds the true per-row maximum over the full $L_d$ axis, and the program adds $\sum \Mm$ to its scalar accumulator. After the outer loop finishes, that accumulator is written to HBM. One pair, one scalar, one store.
4. Why the naive form is bandwidth-bound
The textbook implementation of MaxSim, the one inside PyLate's reference code, separates the two reductions:
```python
import torch

# Q: [Nq, Lq, d], D: [Nd, Ld, d]
S = torch.einsum("nsd,mtd->nmst", Q, D)   # [Nq, Nd, Lq, Ld], lives in HBM
M = S.max(dim=-1).values                  # [Nq, Nd, Lq], max over t
scores = M.sum(dim=-1)                    # [Nq, Nd], sum over s
```
The full similarity tensor $\Sm$ is written to HBM and then read back to take the row-wise max; the much smaller tensor of row maxima is written and read again for the sum. The memory footprint of $\Sm$ is
$$\text{bytes}(\Sm) \;=\; N_q \cdot N_d \cdot L_q \cdot L_d \cdot \text{sizeof(float)}.$$
Scaling is quadratic in the sequence length when $L_q$ and $L_d$ grow together. At ColPali shapes ($N_q = N_d = 64$, $L_q = L_d = 1024$) the tensor holds $2^{32}$ elements: $16$ GiB in fp32, $8$ GiB in fp16. Training builds three or four such intermediates per step.
[Chart: footprint of $\Sm$ as a function of the sequence length.]
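The footprint formula is easy to check by hand. The short sketch below evaluates it at the shapes quoted above; the helper name `similarity_tensor_bytes` is made up for this illustration.

```python
def similarity_tensor_bytes(n_q, n_d, l_q, l_d, bytes_per_element):
    """Size in bytes of the full [Nq, Nd, Lq, Ld] similarity tensor."""
    return n_q * n_d * l_q * l_d * bytes_per_element

GIB = 1024 ** 3
fp32 = similarity_tensor_bytes(64, 64, 1024, 1024, 4)
fp16 = similarity_tensor_bytes(64, 64, 1024, 1024, 2)
print(fp32 / GIB, fp16 / GIB)  # 16.0 8.0

# Doubling both sequence lengths quadruples the footprint.
doubled = similarity_tensor_bytes(64, 64, 2048, 2048, 4)
print(doubled / fp32)  # 4.0
```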
Beyond the memory pressure, the naive version is bandwidth-bound: each element of $\Sm$ is written to HBM once and read back once, in exchange for only $2d$ FLOPs of matmul work. The arithmetic intensity is far below the H100 ridge point, so the SMs spend most of their time waiting on memory. Appendix B works the FLOP-per-byte numbers out and places both implementations on the roofline.
5. Fused forward: tile, stream, accumulate
Spelling out the loop nest sketched in section 3: one Triton program per $(q\text{-batch}, d\text{-batch})$ pair, document tiles of $\text{BLOCK\_D}$ rows nested inside query tiles of $\text{BLOCK\_Q}$ rows. The full similarity tensor $\Sm$ never exists in HBM. Only one $\text{BLOCK\_Q} \times \text{BLOCK\_D}$ slice exists at a time, and it lives in registers for the duration of one tile-matmul; the operand tiles are staged through SRAM.
```text
# one program per (q_batch, d_batch)
score ← 0
for q_start in range(0, Lq, BLOCK_Q):
    Q_tile ← Q[q_batch, q_start : q_start + BLOCK_Q]      # SRAM
    m ← −∞                                                # [BLOCK_Q] in registers
    for d_start in range(0, Ld, BLOCK_D):
        D_tile ← D[d_batch, d_start : d_start + BLOCK_D]  # SRAM
        S_tile ← Q_tile @ D_tileᵀ                         # tensor cores, fp32 acc
        S_tile ← mask(S_tile, d_active, −∞)               # fused-in masking
        m ← maximum(m, rowmax(S_tile))                    # online max
    score ← score + sum(m)                                # contribution from this Q tile
scores[q_batch, d_batch] ← score                          # only this scalar is written
```
The structure is the same outer-product tiling that FlashAttention uses, with one simplification: where FlashAttention has to maintain a running max and a running sum of exponentials for softmax, MaxSim only has a running max. The recurrence is exact and there is nothing to rescale.
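To see that the tiled loop reproduces the naive score, the loop nest can be emulated on the CPU. This is a NumPy sketch of what one program computes, not the Triton kernel itself; `maxsim_pair_tiled` and its arguments are hypothetical names, and the default block sizes are kept tiny so that ragged edge tiles are exercised.

```python
import numpy as np

def maxsim_pair_tiled(q, d, block_q=4, block_d=4, d_mask=None):
    """Score one (query, document) pair. q: [Lq, dim], d: [Ld, dim]."""
    l_q, l_d = q.shape[0], d.shape[0]
    score = 0.0
    for q_start in range(0, l_q, block_q):
        q_tile = q[q_start:q_start + block_q]          # stays resident
        m = np.full(q_tile.shape[0], -np.inf)          # running row max
        for d_start in range(0, l_d, block_d):
            d_tile = d[d_start:d_start + block_d]
            s_tile = q_tile @ d_tile.T                 # at most [block_q, block_d]
            if d_mask is not None:
                active = d_mask[d_start:d_start + block_d]
                s_tile = np.where(active[None, :], s_tile, -np.inf)
            m = np.maximum(m, s_tile.max(axis=1))      # online max
        score += m.sum()                               # this Q tile's contribution
    return score

rng = np.random.default_rng(0)
q = rng.standard_normal((10, 8))
d = rng.standard_normal((13, 8))
naive = (q @ d.T).max(axis=1).sum()                    # materialises all of S
tiled = maxsim_pair_tiled(q, d)
print(np.isclose(naive, tiled))  # True
```

The largest intermediate in the tiled version is one `block_q × block_d` array, regardless of `Lq` and `Ld`.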
Per program, SRAM holds only the operand tiles, $(\text{BLOCK\_Q} + \text{BLOCK\_D}) \cdot d$ elements in low precision, and Triton double-buffers them so the next $\Dm$ tile loads while the current GEMM runs. The score tile $\Sm$ and the running max $\Mm$ live in fp32 registers, never in SRAM: `tl.dot` returns its accumulator straight into the register file, and the row-reduction happens there. At typical block sizes ($\text{BLOCK\_Q} = \text{BLOCK\_D} = 64$, $d = 128$), shared-memory use is around 32 KiB per program, comfortably below the H100's 228 KiB per SM, leaving room for multiple concurrent programs and for the autotuner to trade block size against occupancy.
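The 32 KiB figure follows directly from the tile shapes. The sketch below reproduces it and shows how the budget moves with block size. It counts a single copy of each operand tile at 2 bytes per element, and the "resident programs" column is only a crude upper bound from shared memory alone (it ignores double-buffering and register pressure); both are simplifying assumptions of this illustration.

```python
def operand_tile_bytes(block_q, block_d, dim, bytes_per_element=2):
    """Bytes for one Q tile plus one D tile in bf16/fp16."""
    return (block_q + block_d) * dim * bytes_per_element

KIB = 1024
SMEM_PER_SM_KIB = 228          # H100 figure quoted in the text

single = operand_tile_bytes(64, 64, 128)
print(single / KIB)  # 32.0

for block in (32, 64, 128):
    kib = operand_tile_bytes(block, block, 128) / KIB
    # Upper bound on programs per SM from shared memory alone.
    print(block, kib, int(SMEM_PER_SM_KIB // kib))
```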
Long-query chunking
For ColPali-scale queries ($L_q \approx 1024$ visual patches), the outer loop in the pseudo-code above serialises $\sim\!8$ query-tile iterations inside one program, leaving the rest of the SM grid idle. Because MaxSim is a sum over query tokens of a per-token max, it decomposes exactly:
$$\text{score}[i, j] = \sum_{c=0}^{n_c - 1} \sum_{s \in \text{chunk}_c} \max_t \langle \Qm_{i, s},\, \Dm_{j, t} \rangle.$$
`maxsim()` exploits this whenever $L_q > 512$ by reshaping $\Qm \in \mathbb{R}^{N_q \times L_q \times d}$ into $\Qm_c \in \mathbb{R}^{N_q n_c \times 128 \times d}$ (128 tokens per chunk, $n_c = \lceil L_q / 128 \rceil$), running the kernel over the larger batch, and summing the chunk scores back per original query. The reshape and sum are plain tensor ops, so autograd flows through them unchanged and no custom backward is needed.
Two gains at once: more programs on the grid (better SM utilisation) and the kernel always sees $L_q = 128$, pinning the autotune cache to a small constant rather than one entry per distinct query length.
Splitting only pays off once the per-program serial loop dominates the cost of launching the extra programs. Below that crossover the query loop is already short and the grid already full, so chunking would add launch overhead without exposing any extra parallelism. The split therefore fires only when $L_q > 512$; short queries (ColBERT $L_q \leq 32$, long-document $L_q \leq 512$) take the un-chunked path, which is the faster one for them.
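The exactness of the decomposition can be checked with plain array ops. The sketch below is an assumption-laden illustration rather than the library's `maxsim()`: `maxsim_naive` and `maxsim_chunked` are hypothetical helpers, ragged query lengths are handled by zero-padding plus a query mask, and the demo uses a chunk of 8 tokens instead of 128 to keep the arrays small.

```python
import numpy as np

def maxsim_naive(Q, D, q_mask):
    """Reference MaxSim with a query mask. Q: [Nq, Lq, d], D: [Nd, Ld, d]."""
    S = np.einsum("isd,jtd->ijst", Q, D)
    M = np.where(q_mask[:, None, :], S.max(axis=-1), 0.0)
    return M.sum(axis=-1)

def maxsim_chunked(Q, D, chunk=128):
    """Split each query into fixed-size chunks, score them, sum per query."""
    n_q, l_q, dim = Q.shape
    n_c = -(-l_q // chunk)                          # ceil(Lq / chunk)
    pad = n_c * chunk - l_q
    # Pad the token axis; padded tokens are masked out of the sum.
    Q_pad = np.pad(Q, ((0, 0), (0, pad), (0, 0)))
    mask = np.arange(n_c * chunk) < l_q
    Q_c = Q_pad.reshape(n_q * n_c, chunk, dim)      # larger batch, fixed Lq
    mask_c = np.tile(mask.reshape(n_c, chunk), (n_q, 1))
    chunk_scores = maxsim_naive(Q_c, D, mask_c)     # [Nq * n_c, Nd]
    return chunk_scores.reshape(n_q, n_c, -1).sum(axis=1)

rng = np.random.default_rng(0)
Q = rng.standard_normal((3, 20, 4))                 # Lq = 20 is not a multiple of 8
D = rng.standard_normal((5, 11, 4))
full = maxsim_naive(Q, D, np.ones((3, 20), dtype=bool))
chunked = maxsim_chunked(Q, D, chunk=8)
print(np.allclose(full, chunked))  # True
```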
Step through the algorithm
A toy execution with $L_q = 4$, $L_d = 12$, $d = 4$, $\text{BLOCK\_Q} = 4$ (one outer iteration) and $\text{BLOCK\_D} = 4$ (three inner iterations) illustrates the loop.
[Interactive figure: step-through of the toy execution, with panels "State (current tile)" and "HBM traffic (one program)"; the counters track HBM bytes moved versus the naive baseline.]
6. The online-max recurrence
To avoid ever holding more than one tile of $\Sm$, the inner loop reduces each score tile to its per-row maximum as soon as it is computed and folds that into a running maximum. After consuming the $k$-th tile, the running max obeys
$$\Mm^{(k)} \;=\; \max\!\bigl(\Mm^{(k-1)},\;\; \operatorname{rowmax}(\Sm^{(k)})\bigr),$$
with $\Mm^{(0)} = -\infty$. After all tiles are consumed, $\Mm$ is the true per-row maximum and $\sum_s \Mm_s$ is the score contribution for the current $\text{BLOCK\_Q}$. Both quantities are computed in fp32 in registers; the GEMM uses bf16/fp16 operands with fp32 accumulation, matching the tensor-core native path.
Worked example
Three doc-tiles, one query row. The argmax index is recorded for backward.
| Tile | $\Sm$ row (values seen this tile) | tile max | running max after tile | argmax (global $t$) |
|---|---|---|---|---|
| 0 | [0.42, 0.11, 0.30, 0.18] | 0.42 | 0.42 | 0 |
| 1 | [0.20, 0.55, 0.05, 0.31] | 0.55 | 0.55 | 5 |
| 2 | [0.49, 0.40, 0.50, 0.22] | 0.50 | 0.55 | 5 |
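The table can be reproduced with a few lines of plain Python. This is only a sketch of the recurrence for a single query row; the variable names are illustrative, and the strict `>` comparison is what gives the lowest global index on ties.

```python
tiles = [
    [0.42, 0.11, 0.30, 0.18],
    [0.20, 0.55, 0.05, 0.31],
    [0.49, 0.40, 0.50, 0.22],
]

running_max = float("-inf")
running_arg = -1
offset = 0                                   # global index of the tile's first column
history = []
for tile in tiles:
    tile_max = max(tile)
    tile_arg = tile.index(tile_max)          # lowest index wins ties within a tile
    if tile_max > running_max:               # strict: an earlier tile wins ties
        running_max = tile_max
        running_arg = offset + tile_arg
    history.append((tile_max, running_max, running_arg))
    offset += len(tile)

print(history)
# [(0.42, 0.42, 0), (0.55, 0.55, 5), (0.5, 0.55, 5)]
```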
Masked positions are written as $-\infty$ before the reduction, so they cannot win the max or the argmax even when every valid score is negative. This is stricter than PyLate's reference, which post-multiplies by a $0/1$ mask, and matches the masking discipline used by flash-maxsim and FlashAttention.
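The difference between the two masking disciplines shows up as soon as all valid scores are negative. The toy values below are invented for illustration and do not reproduce any particular library's code; they only contrast "multiply by a 0/1 mask" with "write $-\infty$ before the max".

```python
import numpy as np

scores = np.array([-0.30, -0.10, -0.25, 0.70])    # last entry belongs to padding
valid = np.array([True, True, True, False])

# Post-multiplying by a 0/1 mask turns the padded score into 0.0,
# which still beats every valid (negative) score.
post_multiplied = scores * valid
print(post_multiplied.max(), post_multiplied.argmax())   # 0.0 3

# Writing -inf before the reduction keeps padding out of the max entirely.
pre_masked = np.where(valid, scores, -np.inf)
print(pre_masked.max(), pre_masked.argmax())             # -0.1 1
```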
7. Backward: argmax-only gradients
$\max$ is sub-differentiable: only the argmax position carries a gradient, so the backward pass is simpler than for softmax. With $g \in \mathbb{R}^{N_q \times N_d}$ the upstream gradient on the score, a short chain-rule walk (Appendix C) gives:
$$\nabla_{\Qm_{i, s}} \;=\; m^q_{i,s} \cdot \sum_{j} g_{i,j} \cdot \Dm_{j,\; \operatorname{argmax}_t \langle \Qm_{i,s}, \Dm_{j,t}\rangle},$$ $$\nabla_{\Dm_{j, t}} \;=\; \sum_{(i, s)\, :\, \operatorname{argmax}[i, j, s] \,=\, t} m^q_{i,s} \cdot g_{i, j} \cdot \Qm_{i, s}.$$
To avoid recomputing the forward, the forward kernel optionally writes an $[N_q \cdot N_d, L_q]$ int32 buffer of argmax indices (4 MB for a typical training batch). $\nabla_{\Qm}$ is then embarrassingly parallel: one program per $(i, s)$ gathers $\Dm_{j, \text{argmax}}$ across $j$ and produces a single output row.
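The gather structure of $\nabla_{\Qm}$ can be sketched in NumPy as follows. The helper names are hypothetical, the query mask is omitted (all tokens are treated as active), and the sketch stores the argmax as an `[Nq, Nd, Lq]` array rather than the flattened buffer described above.

```python
import numpy as np

def maxsim_forward_with_argmax(Q, D):
    """Return scores [Nq, Nd] and the winning doc-token index [Nq, Nd, Lq]."""
    S = np.einsum("isd,jtd->ijst", Q, D)
    arg = S.argmax(axis=-1)                  # lowest index on ties
    return S.max(axis=-1).sum(axis=-1), arg

def grad_q_from_argmax(D, arg, g):
    """grad_Q[i, s] = sum_j g[i, j] * D[j, arg[i, j, s]]."""
    n_d = D.shape[0]
    # winners[i, j, s] is the document row that query token (i, s) matched in doc j.
    winners = D[np.arange(n_d)[None, :, None], arg]     # [Nq, Nd, Lq, d]
    return np.einsum("ij,ijsd->isd", g, winners)

rng = np.random.default_rng(0)
Q = rng.standard_normal((3, 4, 5))
D = rng.standard_normal((2, 6, 5))
g = rng.standard_normal((3, 2))              # upstream gradient on the scores
scores, arg = maxsim_forward_with_argmax(Q, D)
grad_Q = grad_q_from_argmax(D, arg, g)
print(arg.shape, grad_Q.shape)  # (3, 2, 4) (3, 4, 5)
```

Every output row reads its own winners and writes only to itself, which is why this direction needs no synchronisation.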
Why $\nabla_{\Dm}$ is the hard one
Picture the scatter for a fixed doc-batch index $j$. Each query-token pair $(i, s)$ contributes one row into $\nabla_{\Dm}[j, \cdot]$ at row $\operatorname{argmax}[i, j, s]$. Whether two pairs collide is a runtime property of the argmax. The kernel cannot know it statically.
[Figure: the rows of `grad_D[j, ·]` for one doc-batch.] The bucket sizes are data-dependent: row 5 is a hot spot (4 source pairs land on it); most rows have 0 or 1. Both the per-row fan-in and the empty-row pattern matter for which kernel wins.
The library offers two reduction strategies to handle this contention; `auto` (the default) chooses between them:

| Method | Bitwise reproducible | When it is picked |
|---|---|---|
| `unified` | no (atomic, ≤1e-6 rel) | Default for cross-products. Single-pass fused $\nabla_{\Qm} + \nabla_{\Dm}$. Hoists $\Qm_{i, s}$ out of the doc-batch loop, roughly halving HBM read traffic; $\nabla_{\Dm}$ writes use fp32 atomic add. |
| `lowmem` | yes | Picked for gradient-heavy shapes: KD / hard-negative layouts and high-contention squares ($N_q, N_d \geq 256$, $L_q \leq 64$). Destination-owned: accumulates in fp32 registers and writes $\nabla_{\Qm} / \nabla_{\Dm}$ directly in the input dtype, so no full-size fp32 buffer, no fp32→bf16 transient, no atomics. Roughly halves backward peak memory and is the deterministic path. |
`lowmem` avoids atomics.
The `unified` path lives with the scatter pattern at runtime and uses fp32 atomic adds into a full-size buffer. `lowmem` inverts it: each Triton program owns one $(j, t)$ output row and reduces the contributing $(i, s)$ pairs from the saved argmax in fp32 registers via a one-hot matmul, then writes once in the input dtype. That means no atomics, bitwise-stable results across launches, and no fp32→bf16 transient, which is why `auto` routes the high-contention $N_q, N_d \geq 256 \wedge L_q \leq 64$ regime (and every KD / hard-negative layout) there.
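The two reduction strategies can be mimicked on the CPU to show that they compute the same $\nabla_{\Dm}$. This is a NumPy analogy under stated assumptions, not the Triton kernels: `np.add.at` stands in for the atomic scatter-add, the dense one-hot tensor stands in for the per-row one-hot matmul (a real kernel would not materialise it), and the query mask is omitted.

```python
import numpy as np

def grad_d_scatter(Q, arg, g, l_d):
    """Source-owned scatter-add: every (i, j, s) adds into row (j, arg[i, j, s])."""
    n_q, n_d, l_q = arg.shape
    grad = np.zeros((n_d, l_d, Q.shape[-1]))
    contrib = g[:, :, None, None] * Q[:, None, :, :]          # [Nq, Nd, Lq, d]
    j_idx = np.broadcast_to(np.arange(n_d)[None, :, None], arg.shape)
    # np.add.at accumulates correctly when several sources hit the same row.
    np.add.at(grad, (j_idx, arg), contrib)
    return grad

def grad_d_destination_owned(Q, arg, g, l_d):
    """Each output row (j, t) reduces its own bucket via a one-hot contraction."""
    one_hot = (arg[..., None] == np.arange(l_d)).astype(Q.dtype)  # [Nq, Nd, Lq, Ld]
    return np.einsum("ijst,ij,isd->jtd", one_hot, g, Q)

rng = np.random.default_rng(0)
Q = rng.standard_normal((3, 4, 5))
D = rng.standard_normal((2, 6, 5))
g = rng.standard_normal((3, 2))
arg = np.einsum("isd,jtd->ijst", Q, D).argmax(axis=-1)        # saved argmax
a = grad_d_scatter(Q, arg, g, D.shape[1])
b = grad_d_destination_owned(Q, arg, g, D.shape[1])
print(np.allclose(a, b))  # True
```

On a GPU the scatter version's summation order depends on which atomic lands first, while the destination-owned version fixes the order per row; the CPU sketch cannot show that non-determinism, only the equivalence of the sums.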
All paths use Triton's stable argmax (lowest-index
tie-break) on the forward, so only the $\nabla_{\Dm}$ reduction order
distinguishes them numerically.
Backward launch-param autotuning
The backward kernels have no block-tiling dimensions to sweep. Each is one
Triton program per output row streaming a single embedding vector, so they
are autotuned over num_warps and num_stages only,
keyed like the forward so the cache holds one entry per training regime.
8. Variants in the library
The streaming-and-never-storing structure is the entire optimisation. The rest of the library applies the same idea to a few adjacent workloads:
- Variable-length / packed input (`maxsim_varlen`). Real corpora have ragged lengths; padding to $L_d^{\max}$ wastes roughly half the FLOPs. The packed kernel consumes $[\text{total\_tokens}, d]$ tensors with FlashAttention-style `cu_seqlens` offsets and runs the same tiling.
- Fused $D$-side projection (`maxsim_from_hidden`). For corpora stored as ModernBERT hidden states, the kernel folds $\texttt{Linear} \rightarrow \texttt{L2-normalize} \rightarrow \texttt{MaxSim}$ into a single pass. The projected embedding tensor is never written to HBM. The training backward gathers only the winning positions and recomputes the projection locally.
- PLAID / ColBERTv2 (`maxsim_residual`, `maxsim_residual_varlen`). The doc embeddings live on disk as $(\text{centroid index}, \text{quantised residual})$. The kernel decompresses on the fly in SRAM, optionally L2-normalises, and runs MaxSim, all in one program with no dense decompressed tensor ever materialised.
- FP8 MaxSim (`maxsim_inference_fp8`). On Hopper, the GEMM uses the FP8 tensor cores at twice the bf16 throughput; the reduction stays in fp32 to preserve range.
Different inputs, same recipe: keep large intermediates on-chip, run the reductions in registers, and ship only the final scores to HBM.
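To make the packed layout from the first variant concrete, here is a small NumPy sketch of scoring one query against ragged documents addressed through cumulative offsets. The function `maxsim_packed_docs` is hypothetical and is not the signature of `maxsim_varlen`; it only illustrates how `cu_seqlens`-style offsets replace padding.

```python
import numpy as np

def maxsim_packed_docs(q, d_packed, cu_seqlens):
    """Score one query against ragged documents stored back to back.

    q: [Lq, dim]; d_packed: [total_tokens, dim];
    cu_seqlens: [Nd + 1] cumulative token offsets with cu_seqlens[0] == 0.
    """
    scores = np.empty(len(cu_seqlens) - 1)
    for j in range(len(scores)):
        start, end = cu_seqlens[j], cu_seqlens[j + 1]
        doc = d_packed[start:end]                  # exactly this doc's tokens
        scores[j] = (q @ doc.T).max(axis=1).sum()
    return scores

rng = np.random.default_rng(0)
lengths = [5, 2, 9]
docs = [rng.standard_normal((n, 4)) for n in lengths]
d_packed = np.concatenate(docs)                    # [16, 4], no padding rows
cu_seqlens = np.concatenate([[0], np.cumsum(lengths)])
q = rng.standard_normal((3, 4))
packed_scores = maxsim_packed_docs(q, d_packed, cu_seqlens)
print(cu_seqlens.tolist())  # [0, 5, 7, 16]
```

A padded layout for the same three documents would hold $3 \times 9 = 27$ rows instead of 16.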
A. GPU memory hierarchy
Modern GPUs are dramatically faster at arithmetic than at moving data. An H100 peaks at roughly $990$ TFLOP/s of bf16 tensor-core throughput against $3.3$ TB/s of HBM bandwidth, a ratio of about $300$ FLOPs per byte. Any kernel that performs fewer FLOPs than that per byte it moves becomes bandwidth-bound, and the SMs sit idle waiting for memory. The trick to writing a fast kernel is therefore not faster arithmetic, but fewer round-trips to HBM. The hierarchy in play:
| Level | Size | Bandwidth | Where it lives |
|---|---|---|---|
| Registers | 256 KB / SM | ~20 TB/s | Inside each thread |
| SRAM (shared memory) | 228 KB / SM (H100) | ~19 TB/s | On-chip, per SM |
| L2 cache | 50 MB (H100) | ~5 TB/s | On-chip, chip-wide |
| HBM | 80 GB (H100) | ~3.3 TB/s | Off-chip, on the package |
SRAM is roughly six times faster than HBM, and registers are faster still. A fused kernel chains several logical operations back-to-back inside SRAM and registers, writing only the final result to HBM. That is the optimisation strategy used throughout this library.
B. Arithmetic intensity
Whether a kernel is bound by compute or by memory is captured by a single number, the arithmetic intensity, $\mathrm{AI} = \text{FLOPs} / \text{HBM bytes moved}$. Plot achievable throughput against $\mathrm{AI}$ and you get the roofline: a bandwidth slope on the left, a compute plateau on the right, meeting at the ridge point $\mathrm{AI}^{*} = \text{peak compute} / \text{peak bandwidth}$. For an H100 in bf16, $\mathrm{AI}^{*} \approx 295$ FLOPs/byte. Below that you are memory-bound, above it compute-bound.
Scoring one $(\Qm, \Dm)$ pair is $2 L_q L_d \, d$ FLOPs of matmul plus a few cheap reductions. The two implementations differ only in what they push through HBM. The naive form materialises the full score matrix $\Sm$ for the pair, so its traffic is dominated by the $L_q \times L_d$ surface (one write plus one read, 2 bytes each in fp16):
$$\mathrm{AI}_{\text{naive}} \;\approx\; \frac{2 L_q L_d \, d}{4 L_q L_d} \;=\; \frac{d}{2}.$$
The fused form never writes $\Sm$; the only HBM traffic is the operand reads $\Qm$ and $\Dm$:
$$\mathrm{AI}_{\text{fused}} \;\approx\; \frac{2 L_q L_d \, d}{2 (L_q + L_d) \, d} \;=\; \frac{L_q L_d}{L_q + L_d}.$$
At ColPali shapes ($L_q = L_d = 1024$, $d = 128$), the naive form sits at about 64 FLOPs/byte, four to five times below the ridge, so the tensor cores idle while HBM works flat out. The fused form sits at about 512 FLOPs/byte, comfortably above the ridge, where the kernel is finally limited by compute. The same workload moves between regimes purely by changing what crosses HBM.
[Figure: both implementations placed on the H100 roofline.]
Fusing is not faster because the arithmetic is cheaper. It is the same matmul. It is faster because the matmul never has to wait for memory.
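The two intensity formulas and the ridge point are simple enough to evaluate directly. The sketch below does so under the same simplifications as the derivation above (fp16 operands, only the dominant traffic terms). The bandwidth constant of 3.35 TB/s is an assumption chosen to match the quoted ridge of about 295 FLOPs/byte; the text rounds it to 3.3 TB/s.

```python
def ai_naive(dim):
    """FLOPs per HBM byte when the fp16 score matrix is written and read once."""
    return dim / 2

def ai_fused(l_q, l_d):
    """FLOPs per HBM byte when only the fp16 operands cross HBM."""
    return (l_q * l_d) / (l_q + l_d)

PEAK_FLOPS = 990e12       # bf16 tensor-core peak quoted in Appendix A
PEAK_BYTES = 3.35e12      # assumed HBM bandwidth in bytes/s
ridge = PEAK_FLOPS / PEAK_BYTES

def attainable_tflops(ai):
    """Roofline: the lower of the bandwidth slope and the compute plateau."""
    return min(PEAK_FLOPS, ai * PEAK_BYTES) / 1e12

naive, fused = ai_naive(128), ai_fused(1024, 1024)
print(naive, fused)                      # 64.0 512.0
print(f"{ridge:.1f}")                    # 295.5
print(f"{attainable_tflops(naive):.0f} {attainable_tflops(fused):.0f}")  # 214 990
```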
C. Deriving the backward gradients
Section 7 quotes the closed forms for $\nabla_{\Qm}$ and $\nabla_{\Dm}$. They drop out of the forward formula in three short steps: resolve the $\max$, apply the chain rule, then read off the two cases.
Resolving the max
Once the forward has run, define the winner for each query token:
$$t^{\star}(i, j, s) \;=\; \arg\max_{t} \, \langle \Qm_{i, s},\, \Dm_{j, t}\rangle.$$
The score then collapses into a plain sum of inner products at the winning slots — the $\max$ has been resolved into a concrete index:
$$\text{score}(i, j) \;=\; \sum_{s} \langle \Qm_{i, s},\, \Dm_{j,\, t^{\star}(i, j, s)}\rangle.$$
$\max$ is non-smooth, but at any non-tied point its subgradient equals the gradient of whichever term achieves the maximum. So locally, for a perturbation small enough that no argmax flips, $\text{score}(i, j)$ is just a sum of bilinear terms in $\Qm$ and $\Dm$, and $t^{\star}$ is a constant we read out of the saved buffer.
Chain rule
Let $g_{i, j} = \partial L / \partial \text{score}(i, j)$ be the upstream gradient. For any parameter $X$:
$$\frac{\partial L}{\partial X} \;=\; \sum_{i, j} g_{i, j} \cdot \frac{\partial\, \text{score}(i, j)}{\partial X}.$$
Case 1: $\partial / \partial \Qm_{i, s, k}$
$\text{score}(i', j)$ depends on the row $\Qm_{i, s}$ only when $i' = i$, and then only the $s$-th term in the sum contains the component $\Qm_{i, s, k}$. That term is $\langle \Qm_{i, s},\, \Dm_{j,\, t^{\star}(i, j, s)}\rangle$, so the partial is the matching entry of $\Dm$:
$$\frac{\partial\, \text{score}(i, j)}{\partial \Qm_{i, s, k}} \;=\; \Dm_{j,\, t^{\star}(i, j, s),\, k}.$$
Summing over $j$, the only free index left in the chain rule, gives $\nabla_{\Qm_{i, s}}$. Each $j$ picks exactly one row of $\Dm$ out of $L_d$, a pure gather with no collisions.
Case 2: $\partial / \partial \Dm_{j, t, k}$
$\text{score}(i, j')$ depends on $\Dm_{j, \cdot, \cdot}$ only when $j' = j$. Within $\text{score}(i, j) = \sum_s \langle \Qm_{i, s},\, \Dm_{j,\, t^{\star}(i, j, s)}\rangle$, the entry $\Dm_{j, t, k}$ appears in the $s$-th term iff $t^{\star}(i, j, s) = t$. For each such $s$ the partial is $\Qm_{i, s, k}$, and several $s$ may collide on the same $t$, in which case their contributions add:
$$\frac{\partial\, \text{score}(i, j)}{\partial \Dm_{j, t, k}} \;=\; \sum_{s\,:\, t^{\star}(i, j, s) \,=\, t} \Qm_{i, s, k}.$$
Chaining over $i$ with $j$ fixed folds the two sums into one set:
$$\frac{\partial L}{\partial \Dm_{j, t, k}} \;=\; \sum_{(i, s)\,:\, t^{\star}(i, j, s) \,=\, t} g_{i, j} \cdot \Qm_{i, s, k}.$$
That set $\{(i, s) : t^{\star}(i, j, s) = t\}$ is exactly the bucket
the lowmem kernel reduces per output row, and the same
contention pattern the unified path resolves with
atomic_add.
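The closed forms from Case 1 and Case 2 can be sanity-checked against central finite differences. The sketch below assumes no masking and random inputs with no ties in the argmax; all helper names are illustrative, and the dense one-hot tensor is used only for brevity.

```python
import numpy as np

def loss(Q, D, g):
    """L = sum_ij g[i, j] * score[i, j], so dL/dscore = g."""
    S = np.einsum("isd,jtd->ijst", Q, D)
    return float((g * S.max(axis=-1).sum(axis=-1)).sum())

def closed_form_grads(Q, D, g):
    S = np.einsum("isd,jtd->ijst", Q, D)
    t_star = S.argmax(axis=-1)                                  # [Nq, Nd, Lq]
    one_hot = (t_star[..., None] == np.arange(D.shape[1])).astype(Q.dtype)
    grad_q = np.einsum("ij,ijst,jtd->isd", g, one_hot, D)       # Case 1: gather
    grad_d = np.einsum("ij,ijst,isd->jtd", g, one_hot, Q)       # Case 2: bucket sum
    return grad_q, grad_d

def finite_difference(f, x, eps=1e-6):
    """Central differences of the scalar f() with respect to x (perturbed in place)."""
    grad = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        orig = x[idx]
        x[idx] = orig + eps
        hi = f()
        x[idx] = orig - eps
        lo = f()
        x[idx] = orig
        grad[idx] = (hi - lo) / (2 * eps)
    return grad

rng = np.random.default_rng(0)
Q = rng.standard_normal((2, 3, 4))
D = rng.standard_normal((2, 5, 4))
g = rng.standard_normal((2, 2))
grad_q, grad_d = closed_form_grads(Q, D, g)
num_q = finite_difference(lambda: loss(Q, D, g), Q)
num_d = finite_difference(lambda: loss(Q, D, g), D)
ok_q = np.allclose(grad_q, num_q, atol=1e-5)
ok_d = np.allclose(grad_d, num_d, atol=1e-5)
print(ok_q, ok_d)
```

Because the score is locally bilinear once the argmax is fixed, the central difference agrees with the closed form up to rounding as long as the perturbation does not flip any winner.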
The mask falls out for free
The forward gates masked query tokens out of the outer sum, so masked positions contribute nothing to the score and their partials are zero on both sides. The $m^q_{i, s}$ factor in the closed-form expressions is just that gating carried through.