late-interaction-kernels · design walkthrough

LIK: Scoring every query–document token pair without materialising the table

Late-interaction retrieval models compare every query token to every document token. A direct implementation builds a four-dimensional similarity tensor that overflows GPU memory at modern context lengths. The kernels in the late-interaction-kernels (LIK) library compute the same result without ever storing that tensor, by streaming tiles through on-chip memory and folding the reductions into the same pass.

1. The MaxSim score

MaxSim is the scoring function ColBERT-style late-interaction models use. The inputs are two batches of token-level embeddings:

  • $\Qm \in \mathbb{R}^{N_q \times L_q \times d}$: $N_q$ queries, each with $L_q$ token vectors of dimension $d$.
  • $\Dm \in \mathbb{R}^{N_d \times L_d \times d}$: $N_d$ documents, each with $L_d$ token vectors of the same dimension $d$.

For every (query $i$, document $j$) pair, the score is:

$$\operatorname{score}[i, j] \;=\; \sum_{s=1}^{L_q} \; \max_{t=1}^{L_d} \; \langle \Qm_{i, s}, \Dm_{j, t} \rangle.$$

For each query token, find the document token it matches best (highest inner product), then sum those best matches across all query tokens. The result is a score matrix in $\mathbb{R}^{N_q \times N_d}$. Optional boolean masks drop padding tokens: the query mask removes padded query tokens from the outer sum, and the document mask removes padded document tokens from the inner max.
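The following direct NumPy transcription of the formula pins down the shapes and the two masks. It is a minimal sketch, not the library's API. The name `maxsim_reference` and the mask convention (boolean arrays, `True` for real tokens) are assumptions. Like the naive form discussed in section 4, it materialises the full four-dimensional similarity tensor, so it is only suitable for small shapes and testing.

```python
import numpy as np

def maxsim_reference(Q, D, q_mask=None, d_mask=None):
    """Dense MaxSim: Q [Nq, Lq, d], D [Nd, Ld, d] -> scores [Nq, Nd].

    Hypothetical helper. Masks are boolean, True for real tokens:
    q_mask [Nq, Lq], d_mask [Nd, Ld].
    """
    S = np.einsum("isk,jtk->ijst", Q, D)                    # [Nq, Nd, Lq, Ld]
    if d_mask is not None:
        S = np.where(d_mask[None, :, None, :], S, -np.inf)  # padded doc tokens never win the max
    best = S.max(axis=-1)                                   # [Nq, Nd, Lq]: max over t
    if q_mask is not None:
        best = np.where(q_mask[:, None, :], best, 0.0)      # padded query tokens leave the sum
    return best.sum(axis=-1)                                # [Nq, Nd]: sum over s
```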

The problem to solve. Computing this naively builds the full $[N_q, N_d, L_q, L_d]$ similarity tensor first, then reduces. At ColPali scale ($L_q = L_d = 1024$, a batch of 64 queries × 64 documents, i.e. 4,096 scored pairs) that tensor holds $2^{32}$ elements, 8 GB in fp16. Each element is written to HBM once and read back once, so the GPU spends most of its time on memory traffic, not compute. The library avoids both the footprint and the round trip in two ways. It streams $B_q \times B_d$ tiles through on-chip memory, and it folds the reductions into the same pass.

2. Programs and the launch grid

A GPU kernel is a small function the device runs in many parallel instances. Each instance is a program in Triton's vocabulary, or a thread block in CUDA's. Writing a kernel comes down to two design choices: how many programs to launch, and what one program computes.

The MaxSim forward uses a grid of $N_q \cdot N_d$ programs, one per query–document pair. Program $(i, j)$ reads the embeddings of query $i$ and document $j$ from HBM, accumulates a single fp32 scalar, and writes it to position $(i, j)$ of the score matrix. The programs share no state. The hardware scheduler places them on whichever streaming multiprocessors (SMs) are free. It runs as many at once as the SMs' registers and shared memory allow, and runs the rest in later waves.

Figure: the launch grid for $N_q = 4$, $N_d = 8$, with rows $i = 0, \dots, 3$ and columns $j = 0, \dots, 7$. It has 32 programs, one per output cell $(i, j)$, all running in parallel. Each program produces a single fp32 number.

3. Inside one program

Fix a single program at $(i, j)$. It has the $i$-th query and the $j$-th document available in HBM. The most direct implementation of its scalar output looks like

$$\Sm \;=\; \Qm_{i,\,\cdot}\, \Dm_{j,\,\cdot}^{\top} \;\in\; \mathbb{R}^{L_q \times L_d}, \qquad \operatorname{score}[i,j] \;=\; \sum_{s} \max_{t} \Sm_{s, t}.$$

Building $\Sm$ even for a single pair wastes most of the values: every row of $\Sm$ contributes only one number to the final score, the row maximum. The kernel never builds it. Instead the program walks the output in tiles. Pick block sizes $B_q$ and $B_d$ (usually 32 or 64). The program loops over $\lceil L_q / B_q \rceil$ query tiles and, inside each, over $\lceil L_d / B_d \rceil$ document tiles. The largest thing that ever exists is one $B_q \times B_d$ slice of $\Sm$, held in registers for the duration of a single tile-matmul.

Figure: one inner-loop iteration, with all of $\Sm$ in registers.
- The $B_q \times d$ $\Qm$ tile is loaded once per outer iteration. The $d \times B_d$ $\Dm$ tile is loaded once per inner iteration.
- Their product is the $B_q \times B_d$ $\Sm$ tile, which lives in registers only.
- The tile's row max, tile_max, is merged into the running max $m$ via $\max(m, \cdot)$. $m$ holds $B_q$ fp32 values in registers.
- $m$ persists across inner-loop iterations. $\Sm$ and tile_max are recomputed each step.

Each inner-loop iteration loads a fresh $\Dm$ tile and multiplies it against the resident $\Qm$ tile on the tensor cores. It takes the row-wise max of the resulting $\Sm$ tile and merges it into the running max $\Mm$ ($B_q$ fp32 values, one per query row of the tile). After the inner loop finishes, $\Mm$ holds the true per-row maximum over the full $L_d$ axis, and the program adds $\sum \Mm$ to its scalar accumulator. After the outer loop finishes, that accumulator is written to HBM. One pair, one scalar, one store.

This is the entire optimisation. A naive implementation writes the full $\Sm$ tensor to HBM and reads it back for the max (the much smaller max tensor makes a further round-trip for the sum). The fused program reads $\Qm$ and $\Dm$ once, keeps every $\Sm$ slice on-chip, and writes one fp32 scalar per pair. The next section makes the cost difference concrete.

4. Why the naive form is bandwidth-bound

The textbook implementation of MaxSim, the one inside PyLate's reference code, separates the two reductions:

```python
import torch

# Q: [Nq, Lq, d] and D: [Nd, Ld, d] are torch tensors
S = torch.einsum("nsd,mtd->nmst", Q, D)  # [Nq, Nd, Lq, Ld]  ← lives in HBM
M = S.max(dim=-1).values                 # [Nq, Nd, Lq]      ← max over t
scores = M.sum(dim=-1)                   # [Nq, Nd]          ← sum over s
```

The full similarity tensor $\Sm$ is written to HBM, then read back to take the row-wise max; only the much smaller $[N_q, N_d, L_q]$ max tensor is read again for the sum. Its memory footprint is

$$\text{bytes}(\Sm) \;=\; N_q \cdot N_d \cdot L_q \cdot L_d \cdot \text{sizeof(element)}.$$

Scaling is quadratic in the sequence length. At ColPali shapes ($N_q = N_d = 64$, $L_q = L_d = 1024$, fp32) the tensor is $16$ GB; in fp16 it is $8$ GB. Training builds three or four such intermediates per step. The chart shows the same number plotted as a function of the sequence length:
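The footprint formula can be evaluated directly. The two helpers below are hypothetical illustrations. `oom_length` assumes the similarity tensor is the only thing in memory, and it takes 80 GB as $80 \times 10^9$ bytes. That is the reading under which the out-of-memory length comes out at $L \approx 3125$.

```python
import math

def sim_tensor_bytes(n_q, n_d, l_q, l_d, bytes_per_elem):
    """Bytes needed to materialise the [Nq, Nd, Lq, Ld] similarity tensor."""
    return n_q * n_d * l_q * l_d * bytes_per_elem

def oom_length(hbm_bytes, n_q, n_d, bytes_per_elem):
    """Largest L (with L_q = L_d = L) whose similarity tensor alone fits in hbm_bytes."""
    return math.isqrt(hbm_bytes // (n_q * n_d * bytes_per_elem))

print(sim_tensor_bytes(64, 64, 1024, 1024, 2) / 2**30)  # fp16: 8.0 (GiB)
print(sim_tensor_bytes(64, 64, 1024, 1024, 4) / 2**30)  # fp32: 16.0 (GiB)
print(oom_length(80 * 10**9, 64, 64, 2))                # 3125
```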

Figure: peak HBM for one $\Sm$ (0 to 40 GB) against sequence length $L = L_q = L_d$ (0 to 2048).
- Naive: $\Sm$ is materialised in HBM, fp16, $N_q = N_d = 64$.
- Fused: streams tiles, so only the $[N_q, N_d]$ score matrix is stored.

fp32 doubles every naive point. An 80 GB H100 runs out of memory near $L \approx 3125$ even in fp16.

Beyond the memory pressure, the naive version is bandwidth-bound. Each element of $\Sm$ costs $d$ multiply-adds to produce. It then makes a full round trip through HBM (one write, one read) only to feed a single comparison in the max. That works out to $d/2$ FLOPs per byte in fp16, far below the H100 ridge point, so the SMs spend most of their time waiting on memory. Appendix B works the FLOP-per-byte numbers out and places both implementations on the roofline.

5. Fused forward: tile, stream, accumulate

Spelling out the loop nest sketched in section 3: one Triton program per $(q\text{-batch}, d\text{-batch})$ pair, document tiles of $\text{BLOCK\_D}$ rows nested inside query tiles of $\text{BLOCK\_Q}$ rows. The full similarity tensor $\Sm$ never exists in HBM. Only the $\text{BLOCK\_Q} \times \text{BLOCK\_D}$ slice does, and it lives in fp32 registers for the duration of one tile-matmul — SRAM holds just the $\Qm$ and $\Dm$ operand tiles feeding it.

```
# one program per (q_batch, d_batch)
score ← 0
for q_start in range(0, Lq, BLOCK_Q):
    Q_tile ← Q[q_batch, q_start : q_start + BLOCK_Q]      # SRAM
    m      ← −∞                                           # [BLOCK_Q] in registers
    for d_start in range(0, Ld, BLOCK_D):
        D_tile ← D[d_batch, d_start : d_start + BLOCK_D]  # SRAM
        S_tile ← Q_tile @ D_tileᵀ                         # tensor cores, fp32 acc
        S_tile ← mask(S_tile, d_active, −∞)               # fused-in masking
        m      ← maximum(m, rowmax(S_tile))               # online max
    score += sum(m)                                       # contribution from this Q tile
scores[q_batch, d_batch] ← score                          # only this scalar is written
```
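To run the loop nest, the sketch below mirrors the pseudo-code in NumPy on the CPU, with an outer driver playing the role of the launch grid. It is an illustration under assumptions, not the Triton kernel. Masks are boolean arrays with `True` for real tokens. The query-mask step (dropping padded query rows before summing `m`) is our addition, since the pseudo-code shows only the document mask. The "SRAM" and "registers" comments say where the corresponding values live in the real kernel.

```python
import numpy as np

def maxsim_program(Q, D, q_batch, d_batch, q_mask, d_mask, BLOCK_Q=64, BLOCK_D=64):
    """CPU mirror of one program of the fused forward; returns one fp32 score."""
    Lq, Ld = Q.shape[1], D.shape[1]
    score = np.float32(0.0)
    for q_start in range(0, Lq, BLOCK_Q):
        Q_tile = Q[q_batch, q_start:q_start + BLOCK_Q].astype(np.float32)      # SRAM in the kernel
        q_active = q_mask[q_batch, q_start:q_start + BLOCK_Q]
        m = np.full(len(Q_tile), -np.inf, dtype=np.float32)                   # registers in the kernel
        for d_start in range(0, Ld, BLOCK_D):
            D_tile = D[d_batch, d_start:d_start + BLOCK_D].astype(np.float32)  # SRAM in the kernel
            d_active = d_mask[d_batch, d_start:d_start + BLOCK_D]
            S_tile = Q_tile @ D_tile.T                                        # fp32 accumulation
            S_tile = np.where(d_active[None, :], S_tile, -np.inf)            # masked doc tokens never win
            m = np.maximum(m, S_tile.max(axis=1))                             # online max
        score += np.where(q_active, m, 0.0).sum(dtype=np.float32)             # assumed query masking
    return score

def maxsim_fused(Q, D, q_mask, d_mask, **blocks):
    """Launch grid: one independent program per (q_batch, d_batch) output cell."""
    scores = np.empty((Q.shape[0], D.shape[0]), dtype=np.float32)
    for q_batch in range(Q.shape[0]):          # on a GPU these programs run concurrently
        for d_batch in range(D.shape[0]):
            scores[q_batch, d_batch] = maxsim_program(Q, D, q_batch, d_batch, q_mask, d_mask, **blocks)
    return scores
```

Ragged lengths need no special casing here because NumPy slices simply come up short on the last tile. The Triton kernel instead masks out-of-range loads.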

The structure is the same outer-product tiling that FlashAttention uses, with one simplification: where FlashAttention has to maintain a running max and a running sum of exponentials for softmax, MaxSim only has a running max. The recurrence is exact and there is nothing to rescale.

Per program, SRAM holds only the operand tiles, $(\text{BLOCK\_Q} + \text{BLOCK\_D}) \cdot d$ elements in low precision. Triton pipelines them across stages so the next $\Dm$ tile loads while the current GEMM runs. The score tile $\Sm$ and the running max $\Mm$ live in fp32 registers, never in SRAM: tl.dot returns its accumulator straight into the register file, and the row reduction happens there. At typical block sizes ($\text{BLOCK\_Q} = \text{BLOCK\_D} = 64$, $d = 128$, 2-byte elements), the operand tiles cost 32 KiB per pipeline stage. At the usual two to three stages that is 64 to 96 KiB per program, comfortably below the H100's 228 KiB of shared memory per SM. The slack leaves room for more than one resident program per SM and lets the autotuner trade block size against occupancy.

Long-query chunking

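Long query sequences ($L_q > 512$) arise, for example, when images or long documents are encoded on the query side. For them, the outer loop in the pseudo-code above serialises eight or more query-tile iterations inside one program (at $\text{BLOCK\_Q} = 64$). When $N_q \cdot N_d$ is small, the grid then has too few programs to keep every SM busy. But MaxSim is a sum over query tokens of a per-token max, so it decomposes exactly over any split of the query tokens into $n_c$ chunks: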

$$\text{score}[i, j] = \sum_{c=0}^{n_c - 1} \sum_{s \in \text{chunk}_c} \max_t \langle \Qm_{i, s},\, \Dm_{j, t} \rangle.$$

When $L_q > 512$, maxsim() reshapes $\Qm$ into $N_q n_c$ chunks of 128 tokens ($n_c = \lceil L_q / 128 \rceil$), runs the kernel over the larger batch, and sums the chunk scores back per query — plain tensor ops, so autograd flows through unchanged. Two gains at once: more programs on the grid, and the kernel always sees $L_q = 128$, pinning the autotune cache to one entry instead of one per distinct query length.
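The decomposition can be checked numerically. This NumPy sketch scores each 128-token chunk as if it were its own query and sums the chunk scores back per query. The text does not say how maxsim() handles an $L_q$ that is not a multiple of 128. Here the tail chunk is zero-padded and the padding hidden by the query mask, which is an assumption. The function names are hypothetical.

```python
import numpy as np

def maxsim_dense(Q, D, q_mask):
    """Reference MaxSim with a boolean query mask (no document mask, for brevity)."""
    S = np.einsum("isk,jtk->ijst", Q, D)                   # [Nq, Nd, Lq, Ld]
    return np.where(q_mask[:, None, :], S.max(axis=-1), 0.0).sum(axis=-1)

def maxsim_chunked(Q, D, q_mask, chunk=128):
    """Score each chunk of `chunk` query tokens as its own query, then sum per query."""
    Nq, Lq, dim = Q.shape
    n_c = -(-Lq // chunk)                                  # ceil(Lq / chunk)
    pad = n_c * chunk - Lq
    Q_pad = np.pad(Q, ((0, 0), (0, pad), (0, 0)))          # zero-pad the last chunk...
    m_pad = np.pad(q_mask, ((0, 0), (0, pad)))             # ...and mask the padding out
    chunk_scores = maxsim_dense(
        Q_pad.reshape(Nq * n_c, chunk, dim),               # Nq * n_c "queries" of length chunk
        D,
        m_pad.reshape(Nq * n_c, chunk),
    )                                                      # [Nq * n_c, Nd]
    return chunk_scores.reshape(Nq, n_c, -1).sum(axis=1)   # sum chunk scores back per query

rng = np.random.default_rng(0)
Q, D = rng.standard_normal((2, 300, 8)), rng.standard_normal((3, 50, 8))
q_mask = np.ones((2, 300), dtype=bool)
q_mask[1, 250:] = False
print(np.abs(maxsim_chunked(Q, D, q_mask) - maxsim_dense(Q, D, q_mask)).max())  # rounding-level
```

Only the summation order changes, so the two results agree up to floating-point rounding.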

Below the $L_q = 512$ crossover the serial loop is already short and the grid already full, so the extra launches would only slow things down — short ColBERT-style queries take the un-chunked path. The same logic exempts the 4-D KD layout entirely: it already launches $N_q \cdot K$ programs, enough to fill every SM.

A toy execution

Take $L_q = 4$, $L_d = 12$, $d = 4$, $\text{BLOCK\_Q} = 4$ (one outer iteration) and $\text{BLOCK\_D} = 4$ (three inner iterations). At each step, track the current tile offsets (q_start, d_start), the running max $\Mm$, the partial sum $\sum \Mm$ and, at the end, the final score. Alongside them, track the HBM bytes one program moves, compared with the naive baseline.

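The toy execution can be replayed in a few lines. This NumPy sketch prints the tile state at each inner step and counts the HBM bytes one program moves. The byte accounting is our assumption, not the library's:
- operands and naive intermediates are fp16, and the score write is one fp32 value;
- the naive baseline reads $\Qm$ and $\Dm$, writes and re-reads $\Sm$, writes and re-reads the $L_q$-long max vector, then writes the score.

```python
import numpy as np

ELEM = 2   # assumed bytes per element: fp16 operands and naive intermediates
SCORE = 4  # one fp32 score per (query, document) pair

def fused_one_program(q, doc, block_q=4, block_d=4):
    """Tiled MaxSim for one pair; returns (score, HBM bytes read, HBM bytes written)."""
    reads, score = 0, 0.0
    for q_start in range(0, q.shape[0], block_q):
        q_tile = q[q_start:q_start + block_q]
        reads += q_tile.size * ELEM                    # Q tile: once per outer iteration
        m = np.full(len(q_tile), -np.inf, dtype=np.float32)
        for d_start in range(0, doc.shape[0], block_d):
            d_tile = doc[d_start:d_start + block_d]
            reads += d_tile.size * ELEM                # D tile: once per inner iteration
            s_tile = q_tile.astype(np.float32) @ d_tile.astype(np.float32).T
            m = np.maximum(m, s_tile.max(axis=1))      # S tile is consumed, never stored
            print(f"q_start={q_start} d_start={d_start} m={np.round(m, 3)} partial sum={m.sum():.3f}")
        score += float(m.sum())
    return score, reads, SCORE                         # the only write is the scalar

def naive_bytes(lq, ld, dim):
    """Assumed naive traffic: read Q, D; write + read S; write + read max vector; write score."""
    return (lq + ld) * dim * ELEM + 2 * lq * ld * ELEM + 2 * lq * ELEM + SCORE

rng = np.random.default_rng(0)
q = rng.standard_normal((4, 4)).astype(np.float16)     # L_q = 4, d = 4
doc = rng.standard_normal((12, 4)).astype(np.float16)  # L_d = 12
score, reads, writes = fused_one_program(q, doc)
fused, naive = reads + writes, naive_bytes(4, 12, 4)
print(f"fused {fused} B, naive {naive} B, ratio {naive / fused:.2f}")
```

At this toy size the ratio is modest because $\Sm$ has only 48 entries. The naive term grows with $L_q L_d$ while the fused reads grow with $L_q + L_d$, so the gap widens quickly at realistic lengths.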

6. The online-max recurrence

To avoid keeping more than one slice of $\Sm$ alive, the inner loop reduces each score tile to its row maxima as soon as the tile is computed. It then folds those maxima into a running per-row maximum. After consuming the $k$-th tile, the running max obeys

$$\Mm^{(k)} \;=\; \max\!\bigl(\Mm^{(k-1)},\;\; \operatorname{rowmax}(\Sm^{(k)})\bigr),$$

with $\Mm^{(0)} = -\infty$. After all tiles are consumed, $\Mm$ is the true per-row maximum and $\sum_s \Mm_s$ is the score contribution for the current $\text{BLOCK\_Q}$. Both quantities are computed in fp32 in registers; the GEMM uses bf16/fp16 operands with fp32 accumulation, matching the tensor-core native path.

Why this is simpler than FlashAttention, and exact. Online softmax has to track the running max $\Mm$ and the running sum of exponentials $\ell$ (plus the output accumulator). It must rescale the latter two whenever the running max changes: the rescalers are the $\exp(\Mm^{(k-1)} - \Mm^{(k)})$ factors in FA-1. Online max needs no rescaler. $\max$ is associative and commutative and returns one of its inputs without rounding. So folding the tiles in one by one, in any order, gives bit-for-bit the same result as the offline computation.
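The exactness claim is easy to check. Because $\max$ returns one of its inputs and never rounds, folding tiles in any order reproduces the offline maximum bit for bit. The NumPy sketch below uses one row of $\Sm$ split into 16 tiles of 64. The shapes are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
S_row = rng.standard_normal(1024).astype(np.float32)  # one query row of S over all L_d tokens
tiles = np.split(S_row, 16)                           # 16 doc tiles of BLOCK_D = 64

def online_max(order):
    """Fold tile maxima into a running max: M_k = max(M_{k-1}, rowmax(S_k))."""
    m = np.float32(-np.inf)
    for k in order:
        m = np.maximum(m, tiles[k].max())
    return m

offline = S_row.max()
for _ in range(5):
    order = rng.permutation(len(tiles))               # any tile order
    assert online_max(order) == offline               # bit-for-bit equal, no rescaling
```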

Worked example

Three doc-tiles of $\text{BLOCK\_D} = 4$ (tile $k$ covers $t = 4k, \dots, 4k + 3$), one query row. The argmax index is recorded for backward.

| Tile | $\Sm$ row (values seen this tile) | tile max | running max after tile | argmax (global $t$) |
|---|---|---|---|---|
| 0 | [0.42, 0.11, 0.30, 0.18] | 0.42 | 0.42 | 0 |
| 1 | [0.20, 0.55, 0.05, 0.31] | 0.55 | 0.55 | 5 |
| 2 | [0.49, 0.40, 0.50, 0.22] | 0.50 | 0.55 | 5 |
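The table can be reproduced with a few lines of Python that carry the running max and its global index across tiles. Using a strict `>` when merging keeps the earlier, lower index on ties, which matches the lowest-index tie-break described for the forward. This is a plain-Python sketch, not kernel code.

```python
BLOCK_D = 4
tiles = [
    [0.42, 0.11, 0.30, 0.18],  # tile 0: t = 0..3
    [0.20, 0.55, 0.05, 0.31],  # tile 1: t = 4..7
    [0.49, 0.40, 0.50, 0.22],  # tile 2: t = 8..11
]

m, arg = float("-inf"), -1
for k, row in enumerate(tiles):
    local = max(range(len(row)), key=lambda t: row[t])  # first maximum within the tile
    if row[local] > m:                                 # strict '>': earlier index wins ties
        m, arg = row[local], k * BLOCK_D + local       # convert to a global t
    print(f"tile {k}: tile max {row[local]:.2f}, running max {m:.2f}, argmax {arg}")
```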

Masked positions are written as $-\infty$ before the reduction, so they can never win the max or the argmax. This is stricter than PyLate's reference, which post-multiplies by a $0/1$ mask. Under that scheme a zeroed padding entry still competes in the max, and it wins any row whose real scores are all negative. The $-\infty$ convention matches the masking discipline used by flash-maxsim and FlashAttention.
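A four-element row shows the difference. In this NumPy sketch the real scores are all negative and the last slot is padding. The values are made up for illustration.

```python
import numpy as np

scores = np.array([-0.3, -0.1, -0.7, 0.9])  # one row of S; the last slot is padding
mask = np.array([True, True, True, False])  # True = real document token

# 0/1 post-multiplication: the zeroed pad (0.0) beats every real, negative score
mult = scores * mask
print(mult.max(), mult.argmax())            # 0.0 3 -> the gradient would go to padding

# -inf before the reduction: padding can never win the max or the argmax
neg_inf = np.where(mask, scores, -np.inf)
print(neg_inf.max(), neg_inf.argmax())      # -0.1 1
```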

7. Backward: argmax-only gradients

$\max$ is differentiable wherever the maximum is unique, and there only the argmax position carries a gradient, so the backward pass is simpler than for softmax. Let $g \in \mathbb{R}^{N_q \times N_d}$ be the upstream gradient on the score. A short chain-rule walk (Appendix D) gives:

$$\nabla_{\Qm_{i, s}} \;=\; m^q_{i,s} \cdot \sum_{j} g_{i,j} \cdot \Dm_{j,\; \operatorname{argmax}_t \langle \Qm_{i,s}, \Dm_{j,t}\rangle},$$ $$\nabla_{\Dm_{j, t}} \;=\; \sum_{(i, s)\, :\, \operatorname{argmax}[i, j, s] \,=\, t} m^q_{i,s} \cdot g_{i, j} \cdot \Qm_{i, s}.$$

To avoid recomputing the forward, the forward kernel optionally writes an $[N_q \cdot N_d, L_q]$ int32 buffer of argmax indices (4 MB for a typical training batch). $\nabla_{\Qm}$ is then embarrassingly parallel: one program per $(i, s)$ gathers $\Dm_{j, \text{argmax}}$ across $j$ and produces a single output row.
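Both closed forms can be evaluated from the saved argmax alone. The NumPy sketch below is a dense CPU reference under stated assumptions, not the library's kernels:
- masks are 0/1 float arrays, and only a query mask is modelled;
- `np.add.at` stands in for the atomic scatter of $\nabla_{\Dm}$. It is unbuffered, so repeated indices accumulate.

```python
import numpy as np

def maxsim_forward(Q, D, q_mask):
    """Dense reference. Q [Nq, Lq, d], D [Nd, Ld, d], q_mask [Nq, Lq] of 0/1 floats.
    Returns scores [Nq, Nd] and the saved argmax buffer [Nq, Nd, Lq]."""
    S = np.einsum("isk,jtk->ijst", Q, D)                          # [Nq, Nd, Lq, Ld]
    argmax = S.argmax(axis=-1)                                    # first index wins ties
    best = np.take_along_axis(S, argmax[..., None], axis=-1)[..., 0]
    return (best * q_mask[:, None, :]).sum(axis=-1), argmax

def maxsim_backward(Q, D, g, argmax, q_mask):
    """Gradients of sum(g * scores) w.r.t. Q and D, using only the saved argmax."""
    Nd = D.shape[0]
    w = g[:, :, None] * q_mask[:, None, :]                        # [Nq, Nd, Lq]: g_ij * m^q_is
    j_idx = np.broadcast_to(np.arange(Nd)[None, :, None], argmax.shape)
    # grad_Q: gather D[j, argmax[i, j, s]] for every (i, j, s), weight it, sum over j
    grad_Q = np.einsum("ijs,ijsk->isk", w, D[j_idx, argmax])
    # grad_D: scatter each weighted Q row into the slot it won; colliding
    # (i, s) pairs accumulate because np.add.at is unbuffered
    grad_D = np.zeros_like(D)
    np.add.at(grad_D, (j_idx, argmax), w[..., None] * Q[:, None, :, :])
    return grad_Q, grad_D

rng = np.random.default_rng(0)
Q, D = rng.standard_normal((2, 3, 4)), rng.standard_normal((3, 5, 4))
g = rng.standard_normal((2, 3))
q_mask = np.array([[1.0, 1.0, 0.0], [1.0, 1.0, 1.0]])  # last token of query 0 is padding
scores, argmax = maxsim_forward(Q, D, q_mask)
grad_Q, grad_D = maxsim_backward(Q, D, g, argmax, q_mask)
```

A directional finite difference makes a convenient check. As long as no argmax flips, `(g * scores).sum()` is linear in $\Qm$ and, separately, in $\Dm$. A small step along any direction $V$ therefore changes it by `(grad_Q * V).sum()` up to rounding. `grad_Q[0, 2]` is exactly zero because that query token is masked.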

Why $\nabla_{\Dm}$ is the hard one

Picture the scatter for a fixed doc-batch index $j$. Each query-token pair $(i, s)$ contributes one row into $\nabla_{\Dm}[j, \cdot]$ at row $\operatorname{argmax}[i, j, s]$. Whether two pairs collide is a runtime property of the argmax. The kernel cannot know it statically.

Figure: six source pairs $(i, s)$ with $i \in \{0, 1\}$ and $s \in \{0, 1, 2\}$ scatter into the grad_D[j, ·] rows for one doc-batch.

| Row | Writers |
|---|---|
| grad_D[j, 5] | 4 |
| grad_D[j, 12] | 1 |
| grad_D[j, 18] | 0 |
| grad_D[j, 23] | 1 |
| grad_D[j, L_d−1] | 0 |

The bucket sizes are data-dependent: row 5 is a hot spot (4 source pairs land on it), and most rows have 0 or 1. Both the per-row fan-in and the empty-row pattern matter for which kernel wins.

This scatter is a parallel reduction, and the library implements both classic ways to organise one:

  • unified — source-owned (“push”). One program per query block $(i, s)$. It already holds $\Qm_{i, s}$ for $\nabla_{\Qm}$, so it pushes its $\nabla_{\Dm}$ contribution in the same pass — both gradients from a single read of $\Qm$. Colliding writers are serialised with fp32 atomic_add, so the addition order varies run to run.
  • lowmem — destination-owned (“pull”). One program per $\text{BLOCK\_D}$-row tile of $\nabla_{\Dm}[j, \cdot]$ — every output row still has exactly one writer. The program pulls the contributing $(i, s)$ pairs from the saved argmax through a one-hot matmul, accumulates in fp32 registers, and stores each row once, directly in the input dtype. No atomics, fixed order.
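The two organisations compute the same sums, which a small NumPy sketch makes concrete for one doc-batch index $j$.
- Each row of `Qw` stands for one source pair's weighted contribution $g_{i,j}\, m^q_{i,s}\, \Qm_{i,s}$, flattened over $(i, s)$.
- The function names and the tile loop standing in for programs are illustrative assumptions.
- `np.add.at` plays the role of `atomic_add`.

The argmax values follow the scatter figure: row 5 is hot, rows 12 and 23 take one writer each.

```python
import numpy as np

def grad_D_push(Qw, argmax, Ld):
    """Source-owned: each source row is added into the slot it won.
    np.add.at accumulates repeated indices, like atomic_add on colliding writers."""
    out = np.zeros((Ld, Qw.shape[1]))
    np.add.at(out, argmax, Qw)
    return out

def grad_D_pull(Qw, argmax, Ld, BLOCK_D=4):
    """Destination-owned: one 'program' per BLOCK_D-row tile gathers its contributors
    with a one-hot matmul and stores each of its rows exactly once."""
    out = np.zeros((Ld, Qw.shape[1]))
    for t0 in range(0, Ld, BLOCK_D):
        rows = np.arange(t0, min(t0 + BLOCK_D, Ld))
        one_hot = (argmax[:, None] == rows[None, :]).astype(Qw.dtype)  # [P, tile], mostly zeros
        out[rows] = one_hot.T @ Qw                                      # fixed summation order
    return out

rng = np.random.default_rng(0)
Qw = rng.standard_normal((6, 3))         # six source pairs (i, s), d = 3
argmax = np.array([5, 5, 12, 5, 23, 5])  # row 5 is the hot spot
print(np.abs(grad_D_push(Qw, argmax, 24) - grad_D_pull(Qw, argmax, 24)).max())
```

Most of the one-hot products in the pull version multiply by zero, which is the wasted work the comparison below attributes to lowmem.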
Figure: unified (source-owned, push) versus lowmem (destination-owned, pull).
- Unified: one program per $(i, s)$, here $(0, 0)$, $(1, 4)$ and $(2, 2)$, issues atomic_add into an fp32 buffer in HBM. Two writers collide on grad_D[j, 5] and one lands on grad_D[j, 12]. The hardware serialises colliding writes, so the addition order varies and the result is not reproducible. The fp32 buffer is cast to bf16 afterwards.
- Lowmem: the program that owns the tile $(j, t = 4 \dots 7)$ pulls its contributors from the saved argmax. It sums them in fp32 registers via a one-hot matmul and stores grad_D[j, 4…7] once in bf16. The contributor for row 12 is pulled by the $(j, t = 12 \dots 15)$ owner instead. Each row has exactly one writer, there are no atomics, and the fixed order makes the result bitwise reproducible.
The same three contributions, resolved two ways: unified lets writers collide and pays in atomics and an fp32 buffer; lowmem gives each row one owner and pays in wasted FLOPs (the one-hot is mostly zeros).
Where you accumulate decides what you allocate. Accumulate in memory and the buffer must be fp32 — 4 bytes per element, plus a transient peak while the fp32 buffer and its bf16 cast coexist. Accumulate in registers and HBM only ever receives each final value once, so the kernel can write the input dtype directly. That one move is why lowmem roughly halves backward peak memory.
| | unified · push | lowmem · pull |
|---|---|---|
| Who writes $\nabla_{\Dm}[j, t]$ | every program whose argmax hit it | exactly one program |
| Write primitive | fp32 atomic_add | one plain store, input dtype |
| Accumulates in | HBM (full-size fp32 buffer + cast) | fp32 registers |
| Bitwise reproducible | no (atomic order varies) | yes (fixed matmul order) |
| Weak spot | contention on hot rows | wasted FLOPs $\propto L_q$ (one-hot mostly zeros) |

That last row is the entire routing logic of auto (the default). On long-query × long-doc cross-products the one-hot waste outweighs the atomic pain, so those go to unified. Everything gradient-heavy — KD / hard-negative layouts and high-contention squares ($N_q, N_d \geq 256$, $L_q \leq 64$) — goes to lowmem, which is also the path to force when bitwise-reproducible training matters.

All paths use Triton's stable argmax (lowest-index tie-break) on the forward, so only the $\nabla_{\Dm}$ reduction order distinguishes them numerically.

Backward launch-param autotuning

The backward kernels fall into two tuning regimes. The $\nabla_{\Qm}$ kernels and the unified kernel have no block-tiling dimensions to sweep: each is one Triton program per output row streaming a single embedding vector, so they are autotuned over num_warps and num_stages only, keyed like the forward so the cache holds one entry per training regime. The lowmem $\nabla_{\Dm}$ kernel is the exception — its one-hot matmul is block-tiled just like the forward, one program per (slab, $\text{BLOCK\_D}$-row doc-tile) — so it draws from the forward's config pool and sweeps $\text{BLOCK\_Q} \times \text{BLOCK\_D}$ tile shapes alongside the launch parameters.

A. GPU memory hierarchy

Modern GPUs are dramatically faster at arithmetic than at moving data. An H100 peaks at roughly $990$ TFLOP/s of dense bf16 tensor-core throughput against about $3.3$ TB/s of HBM bandwidth, a ratio of about $300$ FLOPs per byte. Any kernel that performs fewer FLOPs per byte of HBM traffic than that ratio is bandwidth-bound, and its SMs sit idle waiting for memory. The trick to writing a fast kernel is therefore not faster arithmetic but fewer round-trips to HBM. The hierarchy in play:

| Level | Size | Bandwidth | Where it lives |
|---|---|---|---|
| Registers | 256 KB / SM | ~80 TB/s | Inside each thread |
| SRAM (shared memory) | 228 KB / SM (H100) | ~19 TB/s | On-chip, per SM |
| L2 cache | 50 MB (H100) | ~5 TB/s | On-chip, chip-wide |
| HBM | 80 GB (H100) | ~3.3 TB/s | Off-chip, on the package |

SRAM is roughly six times faster than HBM, and registers are another four times faster than SRAM. A fused kernel chains several logical operations back-to-back inside SRAM and registers, writing only the final result to HBM. That is the optimisation strategy used throughout this library.

B. Arithmetic intensity

Whether a kernel is bound by compute or by memory is captured by a single number, the arithmetic intensity, $\mathrm{AI} = \text{FLOPs} / \text{HBM bytes moved}$. Plot achievable throughput against $\mathrm{AI}$ and you get the roofline: a bandwidth slope on the left, a compute plateau on the right, meeting at the ridge point $\mathrm{AI}^{*} = \text{peak compute} / \text{peak bandwidth}$. For an H100 in bf16, $\mathrm{AI}^{*} \approx 295$ FLOPs/byte. Below that you are memory-bound, above it compute-bound.

Scoring one $(\Qm, \Dm)$ pair is $2 L_q L_d \, d$ FLOPs of matmul plus a few cheap reductions. The two implementations differ only in what they push through HBM. The naive form materialises the full $L_q \times L_d$ score matrix $\Sm$. Its traffic is therefore dominated by writing and reading that surface, 2 bytes each way in fp16:

$$\mathrm{AI}_{\text{naive}} \;\approx\; \frac{2 L_q L_d \, d}{4 L_q L_d} \;=\; \frac{d}{2}.$$

The fused form never writes $\Sm$; the only HBM traffic is the operand reads $\Qm$ and $\Dm$:

$$\mathrm{AI}_{\text{fused}} \;\approx\; \frac{2 L_q L_d \, d}{2 (L_q + L_d) \, d} \;=\; \frac{L_q L_d}{L_q + L_d}.$$

At ColPali shapes ($L_q = L_d = 1024$, $d = 128$), the naive form sits at about 64 FLOPs/byte, four to five times below the ridge, so the tensor cores idle while HBM works flat out. The fused form sits at about 512 FLOPs/byte, comfortably above the ridge, where the kernel is finally limited by compute. The same workload moves between regimes purely by changing what crosses HBM:
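The roofline placement can be reproduced from the two formulas above. In this Python sketch the peak figures are assumed H100 SXM numbers: about 990 TFLOP/s dense bf16 and 3.35 TB/s HBM, which gives the ≈ 295 ridge. `attainable` is the standard roofline bound $\min(\text{peak compute}, \mathrm{AI} \times \text{peak bandwidth})$.

```python
PEAK_FLOPS = 990e12  # assumed H100 SXM dense bf16 peak, FLOP/s
PEAK_BW = 3.35e12    # assumed H100 SXM HBM bandwidth, bytes/s

def ai_naive(lq, ld, d):
    """S is written and read back once in fp16: 4 bytes of traffic per element."""
    return (2 * lq * ld * d) / (4 * lq * ld)

def ai_fused(lq, ld, d):
    """Only the fp16 operands Q and D cross HBM."""
    return (2 * lq * ld * d) / (2 * (lq + ld) * d)

def attainable(ai):
    """Roofline: throughput is capped by peak compute or by intensity * bandwidth."""
    return min(PEAK_FLOPS, ai * PEAK_BW)

ridge = PEAK_FLOPS / PEAK_BW
for name, ai in [("naive", ai_naive(1024, 1024, 128)), ("fused", ai_fused(1024, 1024, 128))]:
    print(f"{name}: AI = {ai:.0f} FLOPs/B, attainable = {attainable(ai) / 1e12:.0f} TFLOP/s")
print(f"ridge: {ridge:.0f} FLOPs/B")
```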

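Figure: roofline over arithmetic intensity (FLOPs per HBM byte, log scale, 1 to 1000). The ridge at AI ≈ 295 separates the memory-bound region from the compute-bound region. The naive form sits at ≈ 64 FLOPs/B and the fused form at ≈ 512 FLOPs/B.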

Fusing is not faster because the arithmetic is cheaper. It is the same matmul. It is faster because the matmul never has to wait for memory.

C. The constexpr bargain

A tl.constexpr kernel argument is a value the compiler sees at compile time, so it can generate code tailored to exactly that value. The catch: every new value triggers a fresh compile (a multi-second stall) and adds another autotune-cache entry. That is the bargain, and it is only worth taking when the answer to one question is yes: does knowing the value let the compiler generate better code?

The forward kernel answers it both ways, on two loops that look almost identical:

```python
for q_start in tl.static_range(0, Lq, BLOCK_Q):  # unrolled at compile time → Lq is constexpr
    ...
    for d_start in range(0, Ld, BLOCK_D):        # plain runtime loop → Ld is a normal argument
```

For $L_d$: no. The inner loop is long, so the compiler keeps it as a loop and pipelines it, loading the next $\Dm$ tile while the current one is being multiplied. The machine code is identical whether $L_d$ is 300 or 301; specialising would just recompile once per distinct document length in a ragged corpus, for nothing. $L_d$ stays a plain runtime argument.

For $L_q$: yes. The outer loop runs just $n = \lceil L_q / \text{BLOCK\_Q} \rceil$ iterations, so the compiler unrolls it: it deletes the loop and pastes $n$ copies of the body back-to-back — straight-line code, no counter, no branch. It can only do that if it knows $n$ at compile time, and $n$ comes from $L_q$. That is what makes $L_q$ constexpr.

One problem remains: query lengths float from batch to batch, and each distinct $L_q$ would be its own compile. So the kernel shrinks the domain. $L_q$ is rounded up to the next power of two, $\Qm$ is zero-padded, and the query mask hides the padding — at most nine compiled values between 16 and 4096 in normal use. (Past 4096, $L_q$ passes through unbucketed, with a one-shot warning pointing long-context callers at maxsim_varlen.)
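The bucketing rule fits in a few lines. This is a hypothetical helper, not the library's code. The floor of 16 is inferred from the "between 16 and 4096" range, and the one-shot warning for longer queries is left out.

```python
def bucket_lq(lq, floor=16, cap=4096):
    """Hypothetical: round Lq up to a power of two in [floor, cap]; longer passes through."""
    if lq > cap:
        return lq                                  # unbucketed (the text says a warning fires once)
    return max(floor, 1 << (lq - 1).bit_length())  # next power of two >= lq

buckets = sorted({bucket_lq(n) for n in range(1, 4097)})
print(buckets)  # [16, 32, 64, 128, 256, 512, 1024, 2048, 4096]
```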

Figure: the same loop body compiled two ways.
- range(0, Ld, BLOCK_D) compiles as a loop: one copy of the body, with a branch back taken ⌈Ld / BLOCK_D⌉ times. The trip count can arrive at runtime, so there are no recompiles, and Ld = 300 or 301 gives the same binary.
- tl.static_range(0, Lq, BLOCK_Q) is unrolled. The compiler pastes n = ⌈Lq / BLOCK_Q⌉ copies of the body back-to-back (here n = 4), each with its q_start baked in as a constant: 0, BLOCK_Q, 2·BLOCK_Q, 3·BLOCK_Q. There is no counter and no branch, but the compiler must know n while compiling, so Lq has to be constexpr. The price: a new Lq means a new n, hence a brand-new binary (one recompile plus one autotune entry).
The same loop body, compiled two ways. The loop form never needs its trip count at compile time; the unrolled form is its trip count — change $L_q$ and you have a different binary.
Same-looking loops, opposite decisions. Unroll the short outer loop, pipeline the long inner one. Constexpr trades compilation count for code quality — take the deal only where the compiler uses the value, and where you take it, shrink the domain.

D. Deriving the backward gradients

Section 7 quotes the closed forms for $\nabla_{\Qm}$ and $\nabla_{\Dm}$. They drop out of the forward formula in three short steps: resolve the $\max$, apply the chain rule, then read off the two cases.

Resolving the max

Once the forward has run, define the winner for each query token:

$$t^{\star}(i, j, s) \;=\; \arg\max_{t} \, \langle \Qm_{i, s},\, \Dm_{j, t}\rangle.$$

The score then collapses into a plain sum of inner products at the winning slots — the $\max$ has been resolved into a concrete index:

$$\text{score}(i, j) \;=\; \sum_{s} \langle \Qm_{i, s},\, \Dm_{j,\, t^{\star}(i, j, s)}\rangle.$$

$\max$ is non-smooth, but wherever the maximum is unique (no ties) it is differentiable. There its gradient equals the gradient of whichever term achieves the maximum. So locally, for a perturbation small enough that no argmax flips, $\text{score}(i, j)$ is just a sum of bilinear terms in $\Qm$ and $\Dm$, and $t^{\star}$ is a constant we read out of the saved buffer.

Chain rule

Let $g_{i, j} = \partial L / \partial \text{score}(i, j)$ be the upstream gradient. For any parameter $X$:

$$\frac{\partial L}{\partial X} \;=\; \sum_{i, j} g_{i, j} \cdot \frac{\partial\, \text{score}(i, j)}{\partial X}.$$

Case 1: $\partial / \partial \Qm_{i, s, k}$

$\text{score}(i', j)$ depends on the row $\Qm_{i, s}$ only when $i' = i$, and within $\text{score}(i, j)$ only the $s$-th term of the sum contains the component $\Qm_{i, s, k}$. That term is $\langle \Qm_{i, s},\, \Dm_{j,\, t^{\star}(i, j, s)}\rangle$, so the partial is the matching entry of $\Dm$:

$$\frac{\partial\, \text{score}(i, j)}{\partial \Qm_{i, s, k}} \;=\; \Dm_{j,\, t^{\star}(i, j, s),\, k}.$$

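Summing over $j$, the only free index left in the chain rule, gives

$$\frac{\partial L}{\partial \Qm_{i, s, k}} \;=\; \sum_{j} g_{i, j} \cdot \Dm_{j,\, t^{\star}(i, j, s),\, k}.$$

This is the section 7 expression up to the mask factor $m^q_{i, s}$. Each $j$ picks exactly one row of $\Dm$ out of $L_d$, a pure gather with no collisions.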

Case 2: $\partial / \partial \Dm_{j, t, k}$

$\text{score}(i, j)$ depends on $\Dm_{j_0, \cdot, \cdot}$ only when $j = j_0$. Within $\text{score}(i, j_0) = \sum_s \langle \Qm_{i, s},\, \Dm_{j_0,\, t^{\star}(i, j_0, s)}\rangle$, the entry $\Dm_{j_0, t, k}$ appears in the $s$-th term iff $t^{\star}(i, j_0, s) = t$. For each such $s$ the partial is $\Qm_{i, s, k}$. Several $s$ may select the same $t$, and their contributions add. Writing $j$ for $j_0$:

$$\frac{\partial\, \text{score}(i, j)}{\partial \Dm_{j, t, k}} \;=\; \sum_{s\,:\, t^{\star}(i, j, s) \,=\, t} \Qm_{i, s, k}.$$

In the chain rule only the terms with this doc index $j$ survive, so the sum over $i$ and the sum over $s$ fold into a single sum over one set of pairs:

$$\frac{\partial L}{\partial \Dm_{j, t, k}} \;=\; \sum_{(i, s)\,:\, t^{\star}(i, j, s) \,=\, t} g_{i, j} \cdot \Qm_{i, s, k}.$$

That set $\{(i, s) : t^{\star}(i, j, s) = t\}$ is exactly the bucket the lowmem kernel reduces per output row, and the same contention pattern the unified path resolves with atomic_add.

The mask falls out for free

The forward gates masked query tokens out of the outer sum, so masked positions contribute nothing to the score and their partials are zero on both sides. The $m^q_{i, s}$ factor in the closed-form expressions is just that gating carried through.