LIK: Scoring every query–document token pair without materialising the table
Late-interaction retrieval models compare every query token to every document token. A direct implementation builds a four-dimensional similarity tensor that overflows GPU memory at modern context lengths. The kernels in the late-interaction-kernels (LIK) library compute the same result without ever storing that tensor, by streaming tiles through on-chip memory and folding the reductions into the same pass.
1. The MaxSim score
MaxSim is the scoring function ColBERT-style late-interaction models use. The inputs are two batches of token-level embeddings:
- $\Qm \in \mathbb{R}^{N_q \times L_q \times d}$: $N_q$ queries, each with $L_q$ token vectors of dimension $d$.
- $\Dm \in \mathbb{R}^{N_d \times L_d \times d}$: same shape, for documents.
For every (query $i$, document $j$) pair, the score is:
$$\operatorname{score}[i, j] \;=\; \sum_{s=1}^{L_q} \; \max_{t=1}^{L_d} \; \langle \Qm_{i, s}, \Dm_{j, t} \rangle.$$
For each query token, find the document token it matches best (highest inner product), then sum those best matches across all query tokens. The result is a score matrix in $\mathbb{R}^{N_q \times N_d}$. Optional boolean masks drop padding tokens from the sum and from the inner max.
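For reference, the definition transcribes directly into plain Python. This is a sketch for checking semantics on toy inputs, not LIK's API. The name maxsim_reference, the nested-list layout, and the convention that False in a mask marks padding are all assumptions made here.

```python
def maxsim_reference(Q, D, q_mask=None, d_mask=None):
    """Naive MaxSim: score[i][j] = sum_s max_t <Q[i][s], D[j][t]>.

    Q: Nq x Lq x d nested lists; D: Nd x Ld x d nested lists.
    Masks (optional, hypothetical layout): Nq x Lq and Nd x Ld booleans, False = padding.
    """
    def dot(u, v):
        return sum(a * b for a, b in zip(u, v))

    scores = []
    for i, q in enumerate(Q):
        row = []
        for j, doc in enumerate(D):
            total = 0.0
            for s, q_tok in enumerate(q):
                if q_mask is not None and not q_mask[i][s]:
                    continue  # padded query token: dropped from the outer sum
                best = max(
                    dot(q_tok, d_tok)
                    for t, d_tok in enumerate(doc)
                    if d_mask is None or d_mask[j][t]  # padded doc token: excluded from the max
                )
                total += best
            row.append(total)
        scores.append(row)
    return scores

Q = [[[1.0, 0.0], [0.0, 1.0]]]                              # N_q = 1, L_q = 2, d = 2
D = [[[0.5, 0.5], [2.0, -1.0]], [[0.0, 3.0], [1.0, 1.0]]]   # N_d = 2, L_d = 2, d = 2
print(maxsim_reference(Q, D))  # [[2.5, 4.0]]
```

It assumes every document keeps at least one unmasked token; otherwise the inner max has nothing to range over and Python raises ValueError.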
2. Programs and the launch grid
A GPU kernel is a small function the device runs in many parallel instances. Each instance is a program in Triton's vocabulary, or a thread block in CUDA's. Writing a kernel comes down to two design choices: how many programs to launch, and what one program computes.
The MaxSim forward uses a grid of $N_q \cdot N_d$ programs, one per query–document pair. Program $(i, j)$ pulls the embeddings of query $i$ and document $j$ from HBM, accumulates a single fp32 scalar, and writes it to position $(i, j)$ of the score matrix. The programs are independent of one another: the hardware schedules them onto the GPU's streaming multiprocessors (SMs) as SMs become free, so a grid larger than the GPU can hold at once simply runs in waves, in no guaranteed order.
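The launch model can be emulated in a few lines of plain Python. launch and maxsim_program are hypothetical names, not Triton or LIK functions; the point is only that each program owns exactly one output slot, so shuffling the order in which programs run cannot change the result.

```python
import random

def maxsim_program(i, j, Q, D, scores):
    """Everything program (i, j) does: compute one scalar and store it at scores[i][j]."""
    total = 0.0
    for q_tok in Q[i]:
        total += max(sum(a * b for a, b in zip(q_tok, d_tok)) for d_tok in D[j])
    scores[i][j] = total  # one store; no other program writes this slot

def launch(grid, program, *args, seed=0):
    """Emulate a 2-D launch: run every program id exactly once, in a shuffled order."""
    pids = [(i, j) for i in range(grid[0]) for j in range(grid[1])]
    random.Random(seed).shuffle(pids)  # no ordering guarantee between programs
    for i, j in pids:
        program(i, j, *args)

Q = [[[1.0, 0.0], [0.0, 1.0]]]                              # N_q = 1, L_q = 2, d = 2
D = [[[0.5, 0.5], [2.0, -1.0]], [[0.0, 3.0], [1.0, 1.0]]]   # N_d = 2, L_d = 2, d = 2
scores = [[None] * len(D) for _ in Q]
launch((len(Q), len(D)), maxsim_program, Q, D, scores)
print(scores)  # [[2.5, 4.0]]
```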
3. Inside one program
Fix a single program at $(i, j)$. It has the $i$-th query and the $j$-th document available in HBM. The most direct implementation of its scalar output looks like
$$\Sm \;=\; \Qm_{i,\,\cdot}\, \Dm_{j,\,\cdot}^{\top} \;\in\; \mathbb{R}^{L_q \times L_d}, \qquad \operatorname{score}[i,j] \;=\; \sum_{s} \max_{t} \Sm_{s, t}.$$
Building $\Sm$ even for a single pair wastes most of the values: every row of $\Sm$ contributes only one number to the final score, the row maximum. The kernel never builds it. Instead the program walks the output in tiles. Pick block sizes $B_q$ and $B_d$ (usually 32 or 64). The program loops over $\lceil L_q / B_q \rceil$ query tiles and, inside each, over $\lceil L_d / B_d \rceil$ document tiles. The largest thing that ever exists is one $B_q \times B_d$ slice of $\Sm$, held in registers for the duration of a single tile-matmul.
Each inner-loop iteration loads a fresh $\Dm$ tile, multiplies it against the resident $\Qm$ tile on the tensor cores, takes the row-wise max of the resulting $\Sm$ tile, and merges it into $\Mm$, a running maximum with one fp32 entry per query token in the current tile. After the inner loop finishes, $\Mm$ holds the true per-row maximum over the full $L_d$ axis, and the program adds $\sum \Mm$ to its scalar accumulator. After the outer loop finishes, that accumulator is written to HBM. One pair, one scalar, one store.
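The same walk in plain Python, as a minimal sketch: tiled_program and naive_program are hypothetical names, and the tiny block sizes are chosen only so that the last tiles come out ragged. The tiled version never holds more than one $B_q \times B_d$ slice, yet it returns the same scalar as the version that builds the whole matrix.

```python
import math
import random

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def naive_program(q, doc):
    """Direct form: build the full L_q x L_d matrix, then reduce."""
    S = [[dot(qv, dv) for dv in doc] for qv in q]
    return sum(max(row) for row in S)

def tiled_program(q, doc, B_q=2, B_d=3):
    """Tiled form: only one B_q x B_d slice of S exists at a time."""
    L_q, L_d = len(q), len(doc)
    acc = 0.0                                    # the program's scalar accumulator
    for q_start in range(0, L_q, B_q):           # ceil(L_q / B_q) query tiles
        q_tile = q[q_start:q_start + B_q]        # resident Q tile (last one may be short)
        M = [-math.inf] * len(q_tile)            # running per-row max for this Q tile
        for d_start in range(0, L_d, B_d):       # ceil(L_d / B_d) document tiles
            d_tile = doc[d_start:d_start + B_d]
            S_tile = [[dot(qv, dv) for dv in d_tile] for qv in q_tile]
            M = [max(m, max(row)) for m, row in zip(M, S_tile)]  # merge tile row-maxes
        acc += sum(M)                            # M now covers the full L_d axis
    return acc

rng = random.Random(0)
q = [[float(rng.randint(-3, 3)) for _ in range(4)] for _ in range(5)]    # L_q = 5, d = 4
doc = [[float(rng.randint(-3, 3)) for _ in range(4)] for _ in range(7)]  # L_d = 7
print(tiled_program(q, doc), naive_program(q, doc))  # identical values
```

Integer-valued inputs keep the comparison exact; with arbitrary floats the two versions can differ in the last bits, because they add the row maxima in a different grouping.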
4. Why the naive form is bandwidth-bound
The textbook implementation of MaxSim, the one inside PyLate's reference code, separates the two reductions:
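```python
import torch
Q, D = torch.randn(4, 32, 128), torch.randn(8, 180, 128)  # example [Nq, Lq, d] and [Nd, Ld, d]
S = torch.einsum("nsd, mtd -> nmst", Q, D)   # [Nq, Nd, Lq, Ld] ← lives in HBM
M = S.max(dim=-1).values                     # [Nq, Nd, Lq] ← max over t
scores = M.sum(dim=-1)                       # [Nq, Nd] ← sum over s
```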
The full similarity tensor $\Sm$ is written to HBM, then read back to take the row-wise max; only the much smaller $[N_q, N_d, L_q]$ max tensor is read again for the sum. The footprint of $\Sm$ alone is
$$\text{bytes}(\Sm) \;=\; N_q \cdot N_d \cdot L_q \cdot L_d \cdot \text{sizeof(float)}.$$
Scaling is quadratic in the sequence length. At ColPali shapes ($N_q = N_d = 64$, $L_q = L_d = 1024$, fp32) the tensor is $16$ GiB; in fp16 it is $8$ GiB. Training builds three or four such intermediates per step.
[Chart: the size of $\Sm$ plotted as a function of sequence length.]
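The arithmetic behind those figures, as a quick check (the helper name is illustrative). The sizes are exact powers of two, so they are 16 and 8 GiB, or about 17.2 and 8.6 GB in decimal units.

```python
GiB = 2 ** 30

def sim_tensor_bytes(N_q, N_d, L_q, L_d, bytes_per_elem):
    """Size of a materialised [N_q, N_d, L_q, L_d] similarity tensor."""
    return N_q * N_d * L_q * L_d * bytes_per_elem

fp32 = sim_tensor_bytes(64, 64, 1024, 1024, 4)
fp16 = sim_tensor_bytes(64, 64, 1024, 1024, 2)
print(fp32 / GiB, fp16 / GiB)  # 16.0 8.0
# Quadratic scaling: doubling both sequence lengths quadruples the footprint.
print(sim_tensor_bytes(64, 64, 2048, 2048, 2) / GiB)  # 32.0
```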
Beyond the memory pressure, the naive version is bandwidth-bound. Each element of $\Sm$ costs $2d$ FLOPs to produce, and is then written to HBM once and read back once only to feed a single comparison in the row-wise max. In fp16 that is about $2d / 4 = d/2$ FLOPs per byte, or 64 at $d = 128$, far below the H100 ridge point of roughly 295, so the SMs spend most of their time waiting on memory. Appendix B works the FLOP-per-byte numbers out and places both implementations on the roofline.
5. Fused forward: tile, stream, accumulate
Spelling out the loop nest sketched in section 3: one Triton program per $(q\text{-batch}, d\text{-batch})$ pair, with document tiles of $\text{BLOCK\_D}$ rows nested inside query tiles of $\text{BLOCK\_Q}$ rows. The full similarity tensor $\Sm$ never exists in HBM. At any moment only one $\text{BLOCK\_Q} \times \text{BLOCK\_D}$ slice of it exists, held in fp32 registers for the duration of one tile-matmul; SRAM holds just the $\Qm$ and $\Dm$ operand tiles feeding it.
```python
import numpy as np

# one program per (q_batch, d_batch)
score = 0.0
for q_start in range(0, Lq, BLOCK_Q):
    Q_tile = Q[q_batch, q_start : q_start + BLOCK_Q]          # SRAM
    q_active = q_mask[q_batch, q_start : q_start + BLOCK_Q]
    m = np.full(Q_tile.shape[0], -np.inf, dtype=np.float32)   # [BLOCK_Q] in registers
    for d_start in range(0, Ld, BLOCK_D):
        D_tile = D[d_batch, d_start : d_start + BLOCK_D]      # SRAM
        d_active = d_mask[d_batch, d_start : d_start + BLOCK_D]
        S_tile = Q_tile @ D_tile.T                            # tensor cores, fp32 acc
        S_tile = np.where(d_active[None, :], S_tile, -np.inf) # fused-in masking
        m = np.maximum(m, S_tile.max(axis=1))                 # online max
    score += m[q_active].sum()                                # contribution from this Q tile
scores[q_batch, d_batch] = score                              # only this scalar is written
```
The structure is the same outer-product tiling that FlashAttention uses, with one simplification: where FlashAttention has to maintain a running max and a running sum of exponentials for softmax, MaxSim only has a running max. The recurrence is exact and there is nothing to rescale.
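A small sketch of the contrast, over one row split into three tiles. online_max is the whole state MaxSim carries; online_logsumexp is the running-max-plus-rescaled-sum recurrence of an online softmax normaliser of the kind FlashAttention uses. Both names are illustrative, and neither is taken from LIK or FlashAttention code.

```python
import math

def online_max(tiles):
    """MaxSim's recurrence: the only state is the running max; nothing is rescaled."""
    m = -math.inf
    for tile in tiles:
        m = max(m, max(tile))
    return m

def online_logsumexp(tiles):
    """Online softmax normaliser: running max m plus running sum l of exp(x - m)."""
    m, l = -math.inf, 0.0
    for tile in tiles:
        m_new = max(m, max(tile))
        # whenever the max grows, the old sum must be rescaled by exp(m - m_new)
        l = l * math.exp(m - m_new) + sum(math.exp(x - m_new) for x in tile)
        m = m_new
    return m + math.log(l)

tiles = [[0.42, 0.11, 0.30, 0.18], [0.20, 0.55, 0.05, 0.31], [0.49, 0.40, 0.50, 0.22]]
print(online_max(tiles))  # 0.55
```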
Per program, SRAM holds only the operand tiles, $(\text{BLOCK\_Q} + \text{BLOCK\_D}) \cdot d$ elements in low precision, and Triton double-buffers them so the next $\Dm$ tile loads while the current GEMM runs. The score tile $\Sm$ and the running max $\Mm$ live in fp32 registers, never in SRAM: tl.dot returns its accumulator straight into the register file, and the row reduction happens there. At typical block sizes ($\text{BLOCK\_Q} = \text{BLOCK\_D} = 64$, $d = 128$), the operand tiles cost 32 KiB per pipeline stage, so 64–96 KiB per program at the usual two to three stages. That is comfortably below the H100's 228 KiB per SM, leaving room for multiple concurrent programs and for the autotuner to trade block size against occupancy.
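The SRAM budget above, worked as arithmetic. The 228 KiB per-SM figure is the H100 shared-memory capacity quoted in Appendix A; the helper is illustrative, and it counts SRAM only, so registers or other limits may cap occupancy lower.

```python
KiB = 1024
H100_SMEM_PER_SM = 228 * KiB  # shared memory per SM on H100

def operand_tile_bytes(block_q, block_d, d, bytes_per_elem=2):
    """SRAM for one pipeline stage: one Q tile plus one D tile in bf16/fp16."""
    return (block_q + block_d) * d * bytes_per_elem

per_stage = operand_tile_bytes(64, 64, 128)
for stages in (2, 3):
    total = stages * per_stage
    # stages, KiB per program, programs that fit by SRAM alone
    print(stages, total // KiB, H100_SMEM_PER_SM // total)
# 2 64 3
# 3 96 2
```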
Long-query chunking
For long query sequences ($L_q > 512$, e.g. when images or whole documents serve as queries), the outer loop in the pseudo-code above serialises eight or more query-tile iterations inside one program (at $\text{BLOCK\_Q} = 64$), leaving SMs idle whenever the $N_q \cdot N_d$ programs are too few to fill the GPU. But MaxSim is a sum over query tokens of a per-token max, so it decomposes exactly over any split of the query tokens into chunks:
$$\text{score}[i, j] = \sum_{c=0}^{n_c - 1} \sum_{s \in \text{chunk}_c} \max_t \langle \Qm_{i, s},\, \Dm_{j, t} \rangle.$$
When $L_q > 512$, maxsim() reshapes $\Qm$ into $N_q n_c$ chunks of 128 tokens ($n_c = \lceil L_q / 128 \rceil$), runs the kernel over the larger batch, and sums the chunk scores back per query. These are plain tensor ops, so autograd flows through unchanged. There are two gains at once: more programs on the grid, and the kernel always sees $L_q = 128$, which pins the autotune cache to one entry instead of one per distinct query length.
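The decomposition can be checked directly. In this sketch (hypothetical helper names; the real path reshapes and pads a tensor, whereas here the last chunk is simply a shorter slice), scoring 128-token chunks independently and summing them reproduces the unchunked score.

```python
import random

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def maxsim_pair(q, doc):
    """Score for one (query, document) pair: sum over query tokens of the best match."""
    return sum(max(dot(qv, dv) for dv in doc) for qv in q)

def maxsim_chunked(q, doc, chunk=128):
    """Score each query chunk on its own, then sum the chunk scores back."""
    # In a batched kernel each chunk would be an extra row of the query batch;
    # here the final chunk is just a shorter slice instead of a padded, masked one.
    chunk_scores = [maxsim_pair(q[c:c + chunk], doc) for c in range(0, len(q), chunk)]
    return sum(chunk_scores)

rng = random.Random(0)
q = [[float(rng.randint(-2, 2)) for _ in range(4)] for _ in range(300)]  # L_q = 300 -> 3 chunks
doc = [[float(rng.randint(-2, 2)) for _ in range(4)] for _ in range(20)]
print(maxsim_chunked(q, doc) == maxsim_pair(q, doc))  # True
```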
Below the $L_q = 512$ crossover the serial loop is already short and the grid already full, so the extra programs and the chunk-sum would only slow things down; short ColBERT-style queries take the un-chunked path. The same logic exempts the 4-D KD layout entirely: it already launches $N_q \cdot K$ programs, enough to fill every SM.
Step through the algorithm
[Interactive figure: a toy execution with $L_q = 4$, $L_d = 12$, $d = 4$, $\text{BLOCK\_Q} = 4$ (one outer iteration) and $\text{BLOCK\_D} = 4$ (three inner iterations), showing the state of the current tile and the HBM traffic of one program versus the naive baseline.]
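A static trace of the same toy configuration. The counters are an assumption about what is worth tracking: they count elements moved through HBM by one program (multiply by the element size for bytes), with the naive baseline paying for the operands, the write and re-read of $\Sm$, and the write and re-read of the row maxima. toy_trace is a hypothetical name.

```python
def toy_trace(L_q=4, L_d=12, d=4, BLOCK_Q=4, BLOCK_D=4):
    """Elements one program moves through HBM: fused tiling vs the naive baseline."""
    fused = 0
    for q_start in range(0, L_q, BLOCK_Q):
        fused += min(BLOCK_Q, L_q - q_start) * d       # load the Q tile
        for d_start in range(0, L_d, BLOCK_D):
            fused += min(BLOCK_D, L_d - d_start) * d   # load one D tile (re-loaded per Q tile)
            print(f"Q tile @{q_start}, D tile @{d_start}: fused traffic so far = {fused}")
    fused += 1                                         # store the single score
    naive = (L_q + L_d) * d     # read the operands
    naive += 2 * L_q * L_d      # write S, then read it back for the row max
    naive += 2 * L_q + 1        # write and re-read the row maxima, store the score
    return fused, naive

print(toy_trace())  # (65, 169)
```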
6. The online-max recurrence
To avoid ever holding more than one tile of $\Sm$, the inner loop reduces each tile to its row-wise maxima as soon as it is computed and folds them into a running per-row maximum. After consuming the $k$-th tile, the running max obeys
$$\Mm^{(k)} \;=\; \max\!\bigl(\Mm^{(k-1)},\;\; \operatorname{rowmax}(\Sm^{(k)})\bigr),$$
with $\Mm^{(0)} = -\infty$. After all tiles are consumed, $\Mm$ is the true per-row maximum and $\sum_s \Mm_s$ is the score contribution for the current $\text{BLOCK\_Q}$. Both quantities are computed in fp32 in registers; the GEMM uses bf16/fp16 operands with fp32 accumulation, matching the tensor-core native path.
Worked example
Three doc-tiles, one query row. The argmax index is recorded for backward.
| Tile | $\Sm$ row (values seen this tile) | tile max | running max after tile | argmax (global $t$) |
|---|---|---|---|---|
| 0 | [0.42, 0.11, 0.30, 0.18] | 0.42 | 0.42 | 0 |
| 1 | [0.20, 0.55, 0.05, 0.31] | 0.55 | 0.55 | 5 |
| 2 | [0.49, 0.40, 0.50, 0.22] | 0.50 | 0.55 | 5 |
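The table can be reproduced by carrying a running argmax alongside the running max. This is a sketch, not the kernel's code. The strict > comparison keeps the lower index when a later tile only ties the current maximum, mirroring the lowest-index tie-break described in section 7.

```python
import math

def online_max_argmax(row_tiles, block):
    """Running max and global argmax over a row streamed in tiles of width `block`."""
    m, arg = -math.inf, -1
    history = []
    for k, tile in enumerate(row_tiles):
        local = max(range(len(tile)), key=lambda c: tile[c])  # first index wins ties
        if tile[local] > m:                  # strict '>' keeps the earlier index on ties
            m, arg = tile[local], k * block + local
        history.append((tile[local], m, arg))  # (tile max, running max, global argmax)
    return history

tiles = [[0.42, 0.11, 0.30, 0.18], [0.20, 0.55, 0.05, 0.31], [0.49, 0.40, 0.50, 0.22]]
for row in online_max_argmax(tiles, block=4):
    print(row)
# (0.42, 0.42, 0)
# (0.55, 0.55, 5)
# (0.5, 0.55, 5)
```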
Masked positions are written as $-\infty$ before the reduction, so they cannot influence the argmax even when scores would otherwise be negative. This is stricter than PyLate's reference, which post-multiplies by a $0/1$ mask: a zeroed padding entry then wins the max whenever every real score in the row is negative. The $-\infty$ rule matches the masking discipline used by flash-maxsim and FlashAttention.
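A minimal illustration of the difference between the two masking rules, using a 0/1 post-multiply as described above (not PyLate's actual code). The row and mask values are made up for illustration.

```python
import math

scores = [-0.7, -0.2, -0.9, -0.4]   # one query row; every real score is negative
mask = [True, True, False, False]   # the last two doc tokens are padding

post_mult = [s * m for s, m in zip(scores, mask)]             # padding becomes (-)0.0
neg_inf = [s if m else -math.inf for s, m in zip(scores, mask)]  # padding can never win

def argmax(xs):
    """Index of the first maximum."""
    return max(range(len(xs)), key=lambda i: xs[i])

best_post = argmax(post_mult)  # index 2: a padding token (score 0) beats every real token
best_neg = argmax(neg_inf)     # index 1: the best real token (-0.2)
```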
7. Backward: argmax-only gradients
$\max$ is not differentiable at ties, but away from them its gradient is one-hot: only the argmax position carries a gradient, so the backward pass is simpler than for softmax. With $g \in \mathbb{R}^{N_q \times N_d}$ the upstream gradient on the score and $m^q_{i,s} \in \{0, 1\}$ the query mask, a short chain-rule walk (Appendix D) gives:
$$\nabla_{\Qm_{i, s}} \;=\; m^q_{i,s} \cdot \sum_{j} g_{i,j} \cdot \Dm_{j,\; \operatorname{argmax}_t \langle \Qm_{i,s}, \Dm_{j,t}\rangle},$$ $$\nabla_{\Dm_{j, t}} \;=\; \sum_{(i, s)\, :\, \operatorname{argmax}[i, j, s] \,=\, t} m^q_{i,s} \cdot g_{i, j} \cdot \Qm_{i, s}.$$
To avoid recomputing the forward, the forward kernel optionally writes an $[N_q \cdot N_d, L_q]$ int32 buffer of argmax indices (4 MB for a typical training batch). $\nabla_{\Qm}$ is then embarrassingly parallel: one program per $(i, s)$ gathers $\Dm_{j, \text{argmax}}$ across $j$ and produces a single output row.
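A plain-Python sketch of the backward from a saved argmax buffer, ignoring masks for brevity. forward_with_argmax and backward_from_argmax are hypothetical names, and the loop nest is written for clarity, not in the per-$(i, s)$ program layout the kernels use. $\nabla_{\Qm}$ is a gather of one $\Dm$ row per $j$, while $\nabla_{\Dm}$ is a scatter-add in which pairs that share an argmax accumulate into the same row.

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def forward_with_argmax(Q, D):
    """Scores plus the saved buffer arg[i][j][s] (lowest index on ties)."""
    scores, arg = [], []
    for q in Q:
        s_row, a_row = [], []
        for doc in D:
            idx = [max(range(len(doc)), key=lambda t: dot(qv, doc[t])) for qv in q]
            a_row.append(idx)
            s_row.append(sum(dot(qv, doc[t]) for qv, t in zip(q, idx)))
        scores.append(s_row)
        arg.append(a_row)
    return scores, arg

def backward_from_argmax(Q, D, arg, g):
    """grad_Q by gather, grad_D by scatter-add, both driven by the saved argmax."""
    d = len(Q[0][0])
    grad_Q = [[[0.0] * d for _ in q] for q in Q]
    grad_D = [[[0.0] * d for _ in doc] for doc in D]
    for i, q in enumerate(Q):
        for j, doc in enumerate(D):
            for s, qv in enumerate(q):
                t = arg[i][j][s]                         # no forward recompute needed
                for k in range(d):
                    grad_Q[i][s][k] += g[i][j] * doc[t][k]  # gather: row t of D[j]
                    grad_D[j][t][k] += g[i][j] * qv[k]      # scatter: collisions on t add up
    return grad_Q, grad_D

Q = [[[1.0, 0.0], [0.0, 1.0]]]                              # N_q = 1, L_q = 2, d = 2
D = [[[0.5, 0.5], [2.0, -1.0]], [[0.0, 3.0], [1.0, 1.0]]]   # N_d = 2, L_d = 2, d = 2
scores, arg = forward_with_argmax(Q, D)
grad_Q, grad_D = backward_from_argmax(Q, D, arg, g=[[1.0, 2.0]])
print(arg)     # [[[1, 0], [1, 0]]]
print(grad_Q)  # [[[4.0, 1.0], [0.5, 6.5]]]
```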
Why $\nabla_{\Dm}$ is the hard one
Picture the scatter for a fixed doc-batch index $j$. Each query-token pair $(i, s)$ contributes one row into $\nabla_{\Dm}[j, \cdot]$ at row $\operatorname{argmax}[i, j, s]$. Whether two pairs collide is a runtime property of the argmax. The kernel cannot know it statically.
[Figure: the rows of grad_D[j, ·] for one doc-batch. The bucket sizes are data-dependent: row 5 is a hot spot (4 source pairs land on it), while most rows receive 0 or 1.] Both the per-row fan-in and the empty-row pattern matter for which kernel wins.
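The per-row fan-in is just a histogram of the saved argmax for one doc-batch. The buffer below is made up so that row 5 is a hot spot, as in the example above; fan_in is an illustrative helper.

```python
from collections import Counter

def fan_in(arg, j, L_d):
    """How many (i, s) source pairs land on each row t of grad_D[j]."""
    counts = Counter(arg[i][j][s] for i in range(len(arg)) for s in range(len(arg[i][j])))
    return [counts[t] for t in range(L_d)]

# arg[i][j][s] for N_q = 2 queries, one doc-batch (j = 0), L_q = 4 tokens each
arg = [[[5, 1, 5, 3]], [[5, 0, 5, 7]]]
print(fan_in(arg, j=0, L_d=8))  # [1, 1, 0, 1, 0, 4, 0, 1]
```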
This scatter is a parallel reduction, and the library implements both classic ways to organise one:
- unified: source-owned (“push”). One program per query block $(i, s)$. It already holds $\Qm_{i, s}$ for $\nabla_{\Qm}$, so it pushes its $\nabla_{\Dm}$ contribution in the same pass, producing both gradients from a single read of $\Qm$. Colliding writers are serialised with fp32 atomic_add, so the addition order varies from run to run.
- lowmem: destination-owned (“pull”). One program per $\text{BLOCK\_D}$-row tile of $\nabla_{\Dm}[j, \cdot]$, so every output row has exactly one writer. The program pulls the contributing $(i, s)$ pairs from the saved argmax through a one-hot matmul, accumulates in fp32 registers, and stores each row once, directly in the input dtype. No atomics, fixed order.
unified lets writers collide and pays in atomics and an fp32 buffer; lowmem gives each row one owner and pays in wasted FLOPs (the one-hot is mostly zeros).
lowmem roughly halves backward peak memory, since it never allocates the full-size fp32 $\nabla_{\Dm}$ buffer that unified accumulates into.
| | unified · push | lowmem · pull |
|---|---|---|
| Who writes $\nabla_{\Dm}[j, t]$ | every program whose argmax hit it | exactly one program |
| Write primitive | fp32 atomic_add | one plain store, input dtype |
| Accumulates in | HBM (full-size fp32 buffer + cast) | fp32 registers |
| Bitwise reproducible | no (atomic order varies) | yes (fixed matmul order) |
| Weak spot | contention on hot rows | wasted FLOPs $\propto L_q$ (one-hot mostly zeros) |
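A scalar-level sketch of both organisations for one doc-batch, with $d = 1$ to keep it small. Each contribution is a (target row, vector) pair standing in for $g_{i,j} \cdot \Qm_{i,s}$. push emulates atomics by applying the writes in a caller-chosen order, and pull gives each row one owner that sums a one-hot-weighted pass over all sources. The names and data are illustrative.

```python
def push(contribs, L_d, d, order):
    """Source-owned: every contribution adds itself into its target row.
    `order` stands in for the unpredictable order in which atomics land."""
    out = [[0.0] * d for _ in range(L_d)]
    for n in order:
        t, vec = contribs[n]
        for k in range(d):
            out[t][k] += vec[k]          # emulated fp32 atomic_add
    return out

def pull(contribs, L_d, d):
    """Destination-owned: each row sums a one-hot-weighted pass over all sources."""
    out = []
    for t in range(L_d):                 # exactly one owner per output row
        onehot = [1.0 if src_t == t else 0.0 for src_t, _ in contribs]  # mostly zeros
        out.append([sum(w * vec[k] for w, (_, vec) in zip(onehot, contribs))
                    for k in range(d)])  # fixed source order, every run
    return out

# Three sources collide on row 0 (a hot spot); one lands on row 1.
contribs = [(0, [0.1]), (0, [0.2]), (0, [0.3]), (1, [0.5])]
a = push(contribs, L_d=2, d=1, order=[0, 1, 2, 3])
b = push(contribs, L_d=2, d=1, order=[3, 2, 1, 0])
c = pull(contribs, L_d=2, d=1)
print(a[0][0], b[0][0], c[0][0])  # 0.6000000000000001 0.6 0.6000000000000001
```

Floating-point addition is not associative, so the same three values grouped in two orders give two different doubles. Push is correct in every order but not bitwise stable; pull fixes the order.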
That last row is the entire routing logic of auto (the default). On long-query × long-doc cross-products the one-hot waste outweighs the cost of atomic contention, so those go to unified. Everything gradient-heavy (KD / hard-negative layouts and high-contention squares with $N_q, N_d \geq 256$ and $L_q \leq 64$) goes to lowmem, which is also the path to force when bitwise-reproducible training matters.

All paths use Triton's stable argmax (lowest-index tie-break) on the forward, so only the $\nabla_{\Dm}$ reduction order distinguishes them numerically.
Backward launch-param autotuning
The backward kernels fall into two tuning regimes. The $\nabla_{\Qm}$ kernels and the unified kernel have no block-tiling dimensions to sweep: each is one Triton program per output row streaming a single embedding vector, so they are autotuned over num_warps and num_stages only, keyed like the forward so the cache holds one entry per training regime. The lowmem $\nabla_{\Dm}$ kernel is the exception. Its one-hot matmul is block-tiled just like the forward, with one program per (slab, $\text{BLOCK\_D}$-row doc-tile), so it draws from the forward's config pool and sweeps $\text{BLOCK\_Q} \times \text{BLOCK\_D}$ tile shapes alongside the launch parameters.
A. GPU memory hierarchy
Modern GPUs are dramatically faster at arithmetic than at moving data. An H100 peaks at roughly $990$ TFLOP/s of dense bf16 tensor-core throughput against $3.3$ TB/s of HBM bandwidth, a ratio of about $300$ FLOPs per byte. Any kernel that performs fewer FLOPs than that per byte of HBM traffic is bandwidth-bound, and its SMs sit idle waiting for memory. The trick to writing a fast kernel is therefore not faster arithmetic, but fewer round-trips to HBM. The hierarchy in play:
| Level | Size | Bandwidth | Where it lives |
|---|---|---|---|
| Registers | 256 KB / SM | ~80 TB/s | Inside each thread |
| SRAM (shared memory) | 228 KB / SM (H100) | ~19 TB/s | On-chip, per SM |
| L2 cache | 50 MB (H100) | ~5 TB/s | On-chip, chip-wide |
| HBM | 80 GB (H100) | ~3.3 TB/s | Off-chip, on the package |
SRAM is roughly six times faster than HBM, and registers are another four times faster than SRAM. A fused kernel chains several logical operations back-to-back inside SRAM and registers, writing only the final result to HBM. That is the optimisation strategy used throughout this library.
B. Arithmetic intensity
Whether a kernel is bound by compute or by memory is captured by a single number, the arithmetic intensity, $\mathrm{AI} = \text{FLOPs} / \text{HBM bytes moved}$. Plot achievable throughput against $\mathrm{AI}$ and you get the roofline: a bandwidth slope on the left, a compute plateau on the right, meeting at the ridge point $\mathrm{AI}^{*} = \text{peak compute} / \text{peak bandwidth}$. For an H100 in bf16, $\mathrm{AI}^{*} \approx 295$ FLOPs/byte. Below that you are memory-bound, above it compute-bound.
Scoring one $(\Qm, \Dm)$ pair is $2 L_q L_d \, d$ FLOPs of matmul plus a few cheap reductions. The two implementations differ only in what they push through HBM. The naive form materialises the full score matrix $\Sm$, so its traffic is dominated by the $L_q \times L_d$ surface (one write and one read, 2 bytes each in fp16):
$$\mathrm{AI}_{\text{naive}} \;\approx\; \frac{2 L_q L_d \, d}{4 L_q L_d} \;=\; \frac{d}{2}.$$
The fused form never writes $\Sm$; the only HBM traffic is the operand reads $\Qm$ and $\Dm$:
$$\mathrm{AI}_{\text{fused}} \;\approx\; \frac{2 L_q L_d \, d}{2 (L_q + L_d) \, d} \;=\; \frac{L_q L_d}{L_q + L_d}.$$
At ColPali shapes ($L_q = L_d = 1024$, $d = 128$), the naive form sits at about 64 FLOPs/byte, four to five times below the ridge, so the tensor cores idle while HBM works flat out. The fused form sits at about 512 FLOPs/byte, comfortably above the ridge, where the kernel is finally limited by compute. The same workload moves between regimes purely by changing what crosses HBM.
[Chart: roofline plot placing the naive and fused implementations on either side of the H100 ridge point.]
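The two intensities at ColPali shapes, compared against the ridge. The peak figures are approximate public H100 SXM numbers (dense bf16, and HBM3 at about 3.35 TB/s, rounded to 3.3 TB/s in Appendix A); the helper names are illustrative.

```python
PEAK_FLOPS = 989e12          # H100 dense bf16 tensor-core peak, FLOP/s (approx.)
PEAK_BW = 3.35e12            # H100 HBM3 bandwidth, bytes/s (approx.)
RIDGE = PEAK_FLOPS / PEAK_BW  # about 295 FLOPs/byte

def ai_naive(L_q, L_d, d, elem=2):
    flops = 2 * L_q * L_d * d
    bytes_moved = 2 * elem * L_q * L_d     # write S, read it back
    return flops / bytes_moved

def ai_fused(L_q, L_d, d, elem=2):
    flops = 2 * L_q * L_d * d
    bytes_moved = elem * (L_q + L_d) * d   # read the Q and D operands once
    return flops / bytes_moved

for name, ai in (("naive", ai_naive(1024, 1024, 128)), ("fused", ai_fused(1024, 1024, 128))):
    bound = "compute" if ai >= RIDGE else "memory"
    print(f"{name}: {ai:.0f} FLOPs/byte, {bound}-bound")
# naive: 64 FLOPs/byte, memory-bound
# fused: 512 FLOPs/byte, compute-bound
```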
Fusing is not faster because the arithmetic is cheaper. It is the same matmul. It is faster because the matmul never has to wait for memory.
C. The constexpr bargain
A tl.constexpr kernel argument is a value the compiler sees at compile time, so it can generate code tailored to exactly that value. The catch: every new value triggers a fresh compile (a multi-second stall) and adds another autotune-cache entry. That is the bargain, and it is only worth taking when the answer to one question is yes: does knowing the value let the compiler generate better code?
The forward kernel answers it both ways, on two loops that look almost identical:
```python
for q_start in tl.static_range(0, Lq, BLOCK_Q):   # unrolled at compile time → Lq is constexpr
    ...
    for d_start in range(0, Ld, BLOCK_D):         # plain runtime loop → Ld is a normal argument
        ...
```
For $L_d$: no. The inner loop is long, so the compiler keeps it as a loop and pipelines it, loading the next $\Dm$ tile while the current one is being multiplied. The machine code is identical whether $L_d$ is 300 or 301; specialising would just recompile once per distinct document length in a ragged corpus, for nothing. $L_d$ stays a plain runtime argument.
For $L_q$: yes. The outer loop runs just $n = \lceil L_q / \text{BLOCK\_Q} \rceil$ iterations, so the compiler unrolls it: it deletes the loop and pastes $n$ copies of the body back-to-back — straight-line code, no counter, no branch. It can only do that if it knows $n$ at compile time, and $n$ comes from $L_q$. That is what makes $L_q$ constexpr.
One problem remains: query lengths float from batch to batch, and each distinct $L_q$ would be its own compile. So the kernel shrinks the domain. $L_q$ is rounded up to the next power of two, $\Qm$ is zero-padded, and the query mask hides the padding. That leaves at most nine compiled values between 16 and 4096 in normal use. (Past 4096, $L_q$ passes through unbucketed, with a one-shot warning pointing long-context callers at maxsim_varlen.)
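A sketch of the bucketing rule as described: round up to the next power of two with a floor of 16, and pass lengths above 4096 through with a warning. bucket_lq is a hypothetical name, the floor of 16 is inferred from the 16 to 4096 range quoted above, and the warning text is invented for illustration.

```python
import warnings

MIN_BUCKET, MAX_BUCKET = 16, 4096

def bucket_lq(L_q):
    """Round L_q up to a power of two in [16, 4096]; pass larger values through."""
    if L_q > MAX_BUCKET:
        # Python's default warning filter shows a repeat from the same call site only once
        warnings.warn("L_q > 4096 is not bucketed; consider maxsim_varlen", stacklevel=2)
        return L_q
    return max(MIN_BUCKET, 1 << (L_q - 1).bit_length())

print([bucket_lq(n) for n in (1, 16, 17, 300, 4096)])  # [16, 16, 32, 512, 4096]
print(len({bucket_lq(n) for n in range(1, 4097)}))     # 9
```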
D. Deriving the backward gradients
Section 7 quotes the closed forms for $\nabla_{\Qm}$ and $\nabla_{\Dm}$. They drop out of the forward formula in three short steps: resolve the $\max$, apply the chain rule, then read off the two cases.
Resolving the max
Once the forward has run, define the winner for each query token:
$$t^{\star}(i, j, s) \;=\; \arg\max_{t} \, \langle \Qm_{i, s},\, \Dm_{j, t}\rangle.$$
The score then collapses into a plain sum of inner products at the winning slots — the $\max$ has been resolved into a concrete index:
$$\text{score}(i, j) \;=\; \sum_{s} \langle \Qm_{i, s},\, \Dm_{j,\, t^{\star}(i, j, s)}\rangle.$$
$\max$ is non-smooth, but at any non-tied point it is differentiable, and its gradient equals the gradient of whichever term achieves the maximum. So locally, for a perturbation small enough that no argmax flips, $\text{score}(i, j)$ is just a sum of bilinear terms in $\Qm$ and $\Dm$, and $t^{\star}$ is a constant we read out of the saved buffer.
Chain rule
Let $g_{i, j} = \partial L / \partial \text{score}(i, j)$ be the upstream gradient. For any parameter $X$:
$$\frac{\partial L}{\partial X} \;=\; \sum_{i, j} g_{i, j} \cdot \frac{\partial\, \text{score}(i, j)}{\partial X}.$$
Case 1: $\partial / \partial \Qm_{i, s, k}$
Only scores of query $i$ depend on $\Qm_{i, s}$; $\text{score}(i', j)$ with $i' \neq i$ does not. Within $\text{score}(i, j)$, only the $s$-th term of the sum contains the component $\Qm_{i, s, k}$. That term is $\langle \Qm_{i, s},\, \Dm_{j,\, t^{\star}(i, j, s)}\rangle$, so the partial is the matching entry of $\Dm$:
$$\frac{\partial\, \text{score}(i, j)}{\partial \Qm_{i, s, k}} \;=\; \Dm_{j,\, t^{\star}(i, j, s),\, k}.$$
Summing over $j$, the only free index left in the chain rule, gives $\nabla_{\Qm_{i, s}}$. Each $j$ picks exactly one row of $\Dm$ out of $L_d$, a pure gather with no collisions.
Case 2: $\partial / \partial \Dm_{j, t, k}$
Only scores of document $j$ depend on $\Dm_{j, t, k}$; $\text{score}(i, j')$ with $j' \neq j$ does not. Within $\text{score}(i, j) = \sum_s \langle \Qm_{i, s},\, \Dm_{j,\, t^{\star}(i, j, s)}\rangle$, the entry $\Dm_{j, t, k}$ appears in the $s$-th term iff $t^{\star}(i, j, s) = t$. For each such $s$ the partial is $\Qm_{i, s, k}$, and several $s$ may collide on the same $t$, in which case their contributions add:
$$\frac{\partial\, \text{score}(i, j)}{\partial \Dm_{j, t, k}} \;=\; \sum_{s\,:\, t^{\star}(i, j, s) \,=\, t} \Qm_{i, s, k}.$$
Chaining over $i$ with $j$ fixed folds the two sums into one set:
$$\frac{\partial L}{\partial \Dm_{j, t, k}} \;=\; \sum_{(i, s)\,:\, t^{\star}(i, j, s) \,=\, t} g_{i, j} \cdot \Qm_{i, s, k}.$$
That set $\{(i, s) : t^{\star}(i, j, s) = t\}$ is exactly the bucket the lowmem kernel reduces per output row, and its size is the write contention the unified path resolves with atomic_add.
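The closed forms can be checked numerically. This sketch defines a loss $L = \sum_{i,j} g_{i,j}\,\text{score}(i, j)$, so that the upstream gradient is exactly $g$, computes Case 1 as a gather and Case 2 as a bucketed scatter, and compares every entry against a central finite difference. Inputs are random floats, so ties (where the derivative is undefined) do not occur in practice; masks are omitted, and all helper names are illustrative.

```python
import random

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def loss(Q, D, g):
    """L = sum_ij g_ij * score(i, j), so dL/dscore(i, j) is exactly g_ij."""
    return sum(g[i][j] * sum(max(dot(qv, dv) for dv in doc) for qv in q)
               for i, q in enumerate(Q) for j, doc in enumerate(D))

def closed_form_grads(Q, D, g):
    """Case 1 as a gather, Case 2 as a scatter into the bucket of t*(i, j, s)."""
    d = len(Q[0][0])
    gQ = [[[0.0] * d for _ in q] for q in Q]
    gD = [[[0.0] * d for _ in doc] for doc in D]
    for i, q in enumerate(Q):
        for j, doc in enumerate(D):
            for s, qv in enumerate(q):
                t_star = max(range(len(doc)), key=lambda t: dot(qv, doc[t]))
                for k in range(d):
                    gQ[i][s][k] += g[i][j] * doc[t_star][k]  # Case 1
                    gD[j][t_star][k] += g[i][j] * qv[k]      # Case 2
    return gQ, gD

def finite_diff(Q, D, g, which, idx, eps=1e-6):
    """Central difference of L in one entry: (i, s, k) of Q or (j, t, k) of D."""
    X = Q if which == "Q" else D
    a, b, c = idx
    orig = X[a][b][c]
    X[a][b][c] = orig + eps
    up = loss(Q, D, g)
    X[a][b][c] = orig - eps
    down = loss(Q, D, g)
    X[a][b][c] = orig                                   # restore the input
    return (up - down) / (2 * eps)

rng = random.Random(1)
Q = [[[rng.uniform(-1, 1) for _ in range(3)] for _ in range(3)] for _ in range(2)]  # N_q=2, L_q=3, d=3
D = [[[rng.uniform(-1, 1) for _ in range(3)] for _ in range(4)] for _ in range(2)]  # N_d=2, L_d=4, d=3
g = [[rng.uniform(-1, 1) for _ in range(2)] for _ in range(2)]
gQ, gD = closed_form_grads(Q, D, g)
```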
The mask falls out for free
The forward gates masked query tokens out of the outer sum, so masked positions contribute nothing to the score and their partials are zero on both sides. The $m^q_{i, s}$ factor in the closed-form expressions is just that gating carried through.