DeepSeek Sparse Attention

Original paper · DeepSeek-AI et al., 2025

DeepSeek Sparse Attention (DSA) augments a standard transformer with a lightweight, parallel "lightning indexer." This indexer rapidly scores the entire token history and selects a small top-$k$ set of the most relevant past tokens for each new query. The main "core" attention mechanism (in DeepSeek's case, MLA) then operates only on this sparsely selected set. This approach dramatically cuts down on the compute required for long-context processing.

In DeepSeek's V3.2-Exp model, DSA is instantiated under MLA in an MQA configuration. This means each compressed, latent key-value entry is shared across all query heads. That being said, DSA's lightning indexer is compatible with any core attention variant.

Why DSA?

MLA (which we'll recap shortly) was a pivotal optimization. It introduced a compressed KV cache, decoupling training and inference attention: models could leverage full MHA during training but run efficient MQA at inference. This design yields massive KV-cache savings while preserving high model quality. The problem, however, wasn't fully solved. The computation under MLA remains effectively equivalent to standard attention; the score calculation is still quadratic, $O(L^2)$, in sequence length.

As context lengths grow, especially in complex agentic settings, that $O(L^2)$ FLOPs cost quickly becomes the dominant bottleneck. 2025 has seen a wave of linear and sparse attention variants (we previously explored DeepSeek's NSA, but stable training seemingly proved elusive). DSA represents a more pragmatic and stable intermediate. It bridges the gap by using an extremely cheap indexer (few heads, small head dimension, FP8 precision) to select just $k$ surviving tokens. The expensive main attention mechanism then operates only on these $k$ survivors, bringing the core attention complexity down to $O(L \cdot k)$, where $k \ll L$.
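To put rough numbers on that, here is a back-of-envelope comparison; the 128K context length is an assumption for illustration, and $k = 2048$ is the example budget used later in this post:

# Back-of-envelope: how much of the history the core attention touches per query at decode time.
L = 128 * 1024   # assumed context length in tokens
k = 2048         # top-k budget kept by the indexer (example value used later in this post)
print(f"core attention reads {k / L:.1%} of the history, i.e. {L // k}x fewer tokens per query")
# -> core attention reads 1.6% of the history, i.e. 64x fewer tokens per query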

[Figure: DSA decoding cost]

We'll anchor the rest of this post on the "DSA under MLA" architecture. The diagram below illustrates the two key paths: the new Lightning Indexer (the green path), which scans the history to produce index scores that feed a top-$k$ selector, and the original Core Attention (the gray path), which now consumes only the $k$ selected latent KV entries. You can view this as a fast, compact search stage (the indexer) that gates a powerful but expensive attention stage (MLA); the indexer learns a compact metric space for similarity search, analogous to a learned k-NN.

[Figure: DSA under MLA]

To make sense of the integration, we first need to understand the baseline. We'll recap just enough of MLA to make the DSA modifications obvious, then walk through the indexer's design, connecting the math to the modeling code.

MLA

Let's first establish the MLA baseline. We enter the attention block with hidden states $x \in \mathbb{R}^{B \times S \times d_{\text{model}}}$. MLA's core idea is to factorize queries and keys into two complementary channels: a latent channel (which carries the content, or "what") and a RoPE channel (which carries the position, or "where"). Values live entirely within the latent channel.

The KV Path

First, we produce the key-value pre-activations via the "A" projection, $W^{KV}_{A}$. This tensor is immediately split into two parts: the latent KV ($c^{KV}$) and the RoPE key ($k^R$).

$$\tilde{k}^{A} = x W^{KV}_{A} \in \mathbb{R}^{B \times S \times (d_C + d_{\mathrm{RoPE}})}$$
$$\tilde{k}^{A} \Rightarrow c^{KV} \in \mathbb{R}^{B \times S \times d_C} \;\oplus\; k^{R} \in \mathbb{R}^{B \times S \times d_{\mathrm{RoPE}}}$$

Here, $d_C = \texttt{kv\_lora\_rank}$ is the dimension of the compressed latent KV, and $d_{\mathrm{RoPE}} = \texttt{qk\_rope\_head\_dim}$ is the head dimension of the decoupled RoPE key. The latent path is normalized, while the RoPE path gets its positional information.

$$c^{KV} \leftarrow \mathrm{RMSNorm}(c^{KV}) \qquad k^{R} \leftarrow \mathrm{RoPE}(k^{R})$$

# x: (B, S, d_model)
kvA = wkv_a(x)               # (B, S, d_C + d_RoPE)
cKV, kR = torch.split(kvA, [d_C, d_RoPE], dim=-1)  # (B,S,d_C), (B,S,d_RoPE)

# Apply norm and RoPE
cKV = kv_norm(cKV)
kR = apply_rotary_emb(kR, freqs_cis)   # (B,S,d_RoPE)

These two tensors are our entire key-value cache. For each token, we store just the compact content latent $c^{KV}$ and the decoupled positional key $k^{R}$.

kv_cache[:, start_pos:end_pos, :] = cKV  # (B, S, d_C)
pe_cache[:, start_pos:end_pos, :] = kR   # (B, S, d_RoPE)
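For a sense of scale, the per-token cache is tiny. A minimal sketch, assuming the MLA dimensions from the published DeepSeek-V3 config (treat these values as illustrative; V3.2-Exp may differ):

# Illustrative sizes, assuming DeepSeek-V3's published MLA dimensions
d_C, d_RoPE = 512, 64            # kv_lora_rank, qk_rope_head_dim
H, d_NoPE, d_V = 128, 128, 128   # heads, no-RoPE key head dim, value head dim

per_token_mla  = d_C + d_RoPE                  # 576 cached values per token
per_token_full = H * (d_NoPE + d_RoPE + d_V)   # 40960 if we cached full per-head K and V
print(per_token_full / per_token_mla)          # ~71x fewer cached values per token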

The Query Path

On the query side, we do something similar. We first apply a low-rank "A" projection ($W^{Q}_{A}$) and normalize to get a compressed latent query, $c^Q$. (This $c^Q$ will be reused later by the lightning indexer, a key efficiency.)

$$c^Q = \mathrm{RMSNorm}(x W^{Q}_{A}) \in \mathbb{R}^{B \times S \times d_Q}$$

# x: (B, S, d_model)
cQ  = q_norm(wq_a(x))        # (B, S, d_Q)

Here, $d_Q = \texttt{q\_lora\_rank}$ is the compressed query rank. To form the final per-head queries, we apply a second "B" projection ($W^{Q}_{B}$) from this latent $c^Q$, which expands it into $H$ heads. We then split each head into its no-RoPE ($q^A$) and RoPE ($q^R$) subspaces.

$$q = c^Q W^{Q}_{B} \;\Rightarrow\; \big(q^{A} \in \mathbb{R}^{B \times S \times H \times d_{\text{NoPE}}},\ q^{R} \in \mathbb{R}^{B \times S \times H \times d_{\text{RoPE}}}\big), \qquad q^{R} \leftarrow \mathrm{RoPE}(q^{R})$$

# project to heads, then split
q_full = wq_b(cQ).view(B, S, H, d_NoPE + d_RoPE)
qA, qR = torch.split(q_full, [d_NoPE, d_RoPE], dim=-1)  # (B,S,H,d_NoPE), (B,S,H,d_RoPE)

# Apply RoPE only to the qR slice
qR = apply_rotary_emb(qR, freqs_cis)

Here, $d_{\text{NoPE}} = \texttt{qk\_nope\_head\_dim}$ and $d_{\text{RoPE}} = \texttt{qk\_rope\_head\_dim}$ sum to the full query-key head dimension $\texttt{qk\_head\_dim}$.

MLA "Trick"

Now we have our queries $\langle q^{A}, q^{R} \rangle$ and our cached keys $\langle c^{KV}, k^{R} \rangle$. A critical component of MLA is that we never materialize the full per-head keys for the entire history. Instead, we use an algebraic trick to score directly against the compact caches.

Consider the no-RoPE score contribution. If we did materialize the key for head $h$ at time $t$, we would take the latent $c^{KV}_t \in \mathbb{R}^{d_C}$ and apply the key block of the "B" up-projection, let's call it $W_K \in \mathbb{R}^{d_{\text{NoPE}} \times d_C}$.

$$k^{\text{NoPE}}_{t} = W_K c^{KV}_t \in \mathbb{R}^{d_{\text{NoPE}}}.$$

The score would be the dot product with the query $q^A$:

$$\langle q^{A}, k^{\text{NoPE}}_{t} \rangle = q^{A\top} (W_K c^{KV}_t) = (W_K^\top q^{A})^\top c^{KV}_t.$$

This identity is the trick! Instead of projecting all $t$ past keys up ($W_K c^{KV}_t$), we project the single query $q^A$ down ($W_K^\top q^{A}$) and compute the dot product in the compact latent space of dimension $d_C$. We transform the query once so it can score directly against the $c^{KV}$ cache.

In code, this query transformation $\tilde{q} = W_K^\top q^{A}$ (named q_lat) looks a bit odd, but it's just a reshape and an einsum:

# wkv_b_weight is (H * (d_NoPE + d_V), d_C)
# We view it as (H, d_NoPE + d_V, d_C)
wkv_b = wkv_b_weight.view(H, d_NoPE + d_V, d_C)

# qA: (B, 1, H, d_NoPE)
# Take the first d_NoPE rows (the key block) for each head
q_lat = torch.einsum("bshd,hdc->bshc", qA, wkv_b[:, :d_NoPE])  # (B, 1, H, d_C)

The final scores are the sum of two dot products, both operating on the raw caches:

  1. Latent Score: The transformed query q_lat dotted with the $c^{KV}$ cache.
  2. RoPE Score: The RoPE query $q^R$ dotted with the $k^R$ cache.

# q_lat: (B, 1, H, d_C),  kv_cache: (B, t, d_C)
# qR:    (B, 1, H, d_RoPE), pe_cache: (B, t, d_RoPE)
scores = (
    torch.einsum("bshc,btc->bsht", q_lat, kv_cache[:, :t_end]) +  # latent: \tilde{q}^T c^{KV}_t
    torch.einsum("bshr,btr->bsht", qR,    pe_cache[:, :t_end])    # RoPE:  q^R · k^R_t
) * softmax_scale

Following the softmax, we aggregate the values in latent space using the same $c^{KV}$ cache. This keeps the heaviest computation (the weighted sum) in the compact $d_C$ dimension:

attn  = scores.softmax(dim=-1)                                   # (B, 1, H, t)
x_lat = torch.einsum("bsht,btc->bshc", attn, kv_cache[:, :t_end])  # (B, 1, H, d_C)

Finally, we expand this latent representation $x_{\text{lat}}$ to the full value head dimension $d_V$ using the value block of the $W^{KV}_{B}$ projection (the last $d_V$ rows per head) and project back to model space.

# Use the value rows (last d_V) per head to up-project
x_head = torch.einsum("bshc,hdc->bshd", x_lat, wkv_b[:, -d_V:])  # (B, 1, H, d_V)
x_out  = wo(x_head.flatten(2))                                   # (B, 1, d_model)

This latent trick is mathematically equivalent to standard attention but avoids ever constructing full per-head keys, saving memory and compute during the decode loop.
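A quick, self-contained numerical check of the identity underlying the trick (random tensors, a single head, no batch dimension):

import torch

d_C, d_NoPE = 512, 128
W_K = torch.randn(d_NoPE, d_C)       # key block of the "B" up-projection
q   = torch.randn(d_NoPE)            # a no-RoPE query for one head
c   = torch.randn(d_C)               # one cached latent

score_materialized = q @ (W_K @ c)   # up-project the key, then dot with the query
score_latent       = (W_K.T @ q) @ c # down-project the query, then dot in latent space
assert torch.allclose(score_materialized, score_latent, atol=1e-3)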

DeepSeek Sparse Attention (DSA)

MLA gave us compact per-token caches ($c^{KV}$ and $k^R$) and a decode path that avoids materializing full keys. What it didn't change is how many past tokens we touch: all $t$ of them, making the scoring $O(t)$.

The Lightning Indexer adds a fast, low-dimensional search stage before this scoring. It scans the entire history in a tiny, FP8-quantized space and proposes a top-$k$ candidate set. The expensive MLA scoring logic we just reviewed is then run only on those $k$ tokens.

Building the Indexer Space

The indexer builds its own compact space with $H_I = \texttt{index\_n\_heads}$ (e.g., 64) small heads of width $d_I = \texttt{index\_head\_dim}$ (e.g., 128). This space is separate from the main attention.

Indexer Query: The indexer query path starts from the same compressed query activation $c^Q \in \mathbb{R}^{B \times S \times d_Q}$ that MLA used. This is a key efficiency. We project it to the indexer's head space, split into NoPE and RoPE slices, and apply RoPE.

$$q^{\text{idx}} = c^Q W^{Q,\text{idx}}_{B} \in \mathbb{R}^{B \times S \times H_I \times (d_{I,\text{NoPE}} + d_{I,\text{RoPE}})}$$
$$q^{\text{mix}} = \mathrm{concat}\big(q^{\text{idx}}_{\text{NoPE}},\ \mathrm{RoPE}(q^{\text{idx}}_{\text{RoPE}})\big) \in \mathbb{R}^{B \times S \times H_I \times d_I}$$

q_idx = wq_b_index(cQ).view(B, S, H_I, dI_NoPE + dI_RoPE)   # (B,S,H_I,·)
qI_nope, qI_rope = torch.split(q_idx, [dI_NoPE, dI_RoPE], dim=-1)
qI_rope = apply_rotary_emb(qI_rope, freqs_cis)
q_mix = torch.cat([qI_nope, qI_rope], dim=-1)                # (B,S,H_I,d_I)

Indexer Key: The same process is repeated for the keys, starting from the model hidden state $x$. However, the indexer key path is MQA-style (it's not split into heads), producing a single key vector per token.

$$k^{\text{mix}} \in \mathbb{R}^{B \times S \times d_I}$$
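A hedged sketch of this key path, assuming it mirrors the query-path code above (the projection name wk_index and the NoPE/RoPE split widths are assumptions, not the released implementation):

k_idx = wk_index(x)                                  # (B, S, d_I): a single key per token (MQA-style)
kI_nope, kI_rope = torch.split(k_idx, [dI_NoPE, dI_RoPE], dim=-1)
kI_rope = apply_rotary_emb(kI_rope, freqs_cis)       # RoPE on the positional slice, as for kR earlier
k_mix = torch.cat([kI_nope, kI_rope], dim=-1)        # (B, S, d_I)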

Hadamard Rotation: To decorrelate features and improve the numerical properties for low-precision math, both the queries and keys are rotated using a Walsh-Hadamard transform. Think of this as conditioning the vectors for a robust FP8 search.

$$q^{\text{rot}} = \mathrm{Hadamard}\big(q^{\text{mix}}\big), \qquad k^{\text{rot}} = \mathrm{Hadamard}\big(k^{\text{mix}}\big)$$
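For intuition, an explicit (unfused) version of this rotation can be written with a Hadamard matrix. This is a sketch, not the production kernel; it assumes $d_I$ is a power of two (128 qualifies) and uses scipy's hadamard helper:

import torch
from scipy.linalg import hadamard

def hadamard_rotate(v: torch.Tensor) -> torch.Tensor:
    d = v.shape[-1]                                            # must be a power of two
    H_mat = torch.as_tensor(hadamard(d), dtype=v.dtype, device=v.device)
    return v @ H_mat * d ** -0.5                               # 1/sqrt(d) keeps the transform orthonormal

q_rot = hadamard_rotate(q_mix)   # (B, S, H_I, d_I)
k_rot = hadamard_rotate(k_mix)   # (B, S, d_I)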

Scoring and Selection

The indexer runs entirely in FP8. Both the query and key activations are quantized, and the keys are cached in FP8 along with their per-block quantization scales.

$$(q_{\mathrm{fp8}}, s_q) = \mathrm{quant8}(q^{\text{rot}}), \qquad (k_{\mathrm{fp8}}, s_k) = \mathrm{quant8}(k^{\text{rot}})$$

q_fp8, q_scale = act_quant(q_rot)         # (B,S,H_I,d_I), (B,S,H_I,1 or blocks)
k_fp8, k_scale = act_quant(k_rot)         # (B,S,d_I),     (B,S,blocks)

# Update the FP8 indexer cache
k_cache[:, start:end]       = k_fp8
k_scale_cache[:, start:end] = k_scale
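As a rough mental model of act_quant, a per-block FP8 quantizer might look like the sketch below (the 128-wide blocks and torch.float8_e4m3fn dtype are assumptions; the production path uses fused kernels):

import torch

def act_quant_sketch(x: torch.Tensor, block_size: int = 128):
    *lead, d = x.shape                                     # d must be divisible by block_size
    xb = x.reshape(*lead, d // block_size, block_size)
    scale = xb.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / 448.0  # 448 ≈ e4m3 max
    x_fp8 = (xb / scale).to(torch.float8_e4m3fn).reshape(*lead, d)
    return x_fp8, scale.squeeze(-1)                        # FP8 values + per-block scales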

At decode time ($S = 1$), a specialized fused kernel performs the search. This kernel runs an FP8 GEMM between the query $q_{\mathrm{fp8}}$ and the entire key cache $k_{\mathrm{fp8},\le t}$. This is still an $O(t)$ operation, but the constants are tiny ($H_I$ and $d_I$ are small, and the math is FP8).

The kernel computes a nonnegative similarity (using a ReLU) for each head, then performs a weighted sum across the $H_I$ heads to produce a single scalar "index score" for each past token:

$$\text{logits}_{h,t} = q_{\mathrm{fp8},h} \cdot k_{\mathrm{fp8},t}$$
$$\text{logits}^{+}_{h,t} = \max\big(0,\ \text{logits}_{h,t}\big)$$
$$\text{score}_{t} = \sum_{h=1}^{H_I} w_h \, \text{logits}^{+}_{h,t}$$
$$\text{index\_score}_{t} = \text{score}_{t} \cdot s_k(t)$$

Here, $q_{\mathrm{fp8},h}$ is the FP8 query for indexer head $h$; $k_{\mathrm{fp8},t}$ is the FP8 key for token $t$; $\max(0,\cdot)$ is a ReLU that discards negative correlations; $w_h$ is a learned scalar gate (derived from the current $x$) that weights the importance of each indexer head; and $s_k(t)$ is the per-block dequantization scale for key $t$, which restores the magnitude after the FP8 dot product.

# These weights are derived from x to combine heads cheaply inside the kernel
q_weights = head_weight_proj(x).view(B, S, H_I) * H_I ** -0.5

# The fused fp8_index kernel does all the steps above:
# (B,S,H_I,d_I) @ (B,t,d_I) -> (B,S,t)
# It applies the ReLU, head weighting (q_weights), and k_scales internally.
index_score = fp8_index(
    q_fp8,                     # (B,S,H_I,d_I)
    q_weights,                 # (B,S,H_I)
    k_cache[:, :t_end],        # (B,t,d_I)    FP8
    k_scale_cache[:, :t_end],  # (B,t,blocks) per-block scales
)  # → (B,S,t)
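Ignoring quantization, the kernel's math reduces to two einsums and a ReLU. A float reference for intuition only (k_rot_cache is a hypothetical unquantized cache of the rotated indexer keys, not a real buffer in the implementation):

logits = torch.einsum("bshd,btd->bsht", q_rot, k_rot_cache[:, :t_end])      # (B, S, H_I, t)
index_score_ref = torch.einsum("bsht,bsh->bst", logits.relu(), q_weights)   # (B, S, t)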

Finally, we take the top-$k$ of these scores to find the indices of the $k$ (e.g., 2048) most relevant tokens.
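In code, the selection is a plain topk over the index scores (index_topk is an assumed name for the configured budget; it is capped by how many tokens exist so far):

k_sel = min(index_topk, t_end)                           # e.g. 2048, but never more than t
topk_idx = index_score.topk(k_sel, dim=-1).indices       # (B, S, k_sel) positions into the history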

The complete DSA loop is as follows:

  1. Build compact FP8 search vectors (Q and K) for the indexer.
  2. Run a fast, fused FP8 search to score all $t$ past tokens.
  3. Select the top-$k$ token indices.
  4. Pass only these $k$ indices to the main MLA attention layer, which performs its expensive $O(k)$ scoring and aggregation.

By filtering the history, DSA lets MLA operate on a tiny fraction of the full context, achieving massive computational savings while maintaining high performance.
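To tie the two stages together, here is a sketch of a single decode step ($S = 1$) that reuses the tensors and helper names from the walkthrough above; explicit gathers stand in for the fused sparse kernels the real implementation uses:

# Gather only the selected latents and positional keys
idx = topk_idx[:, 0]                                                   # (B, k)
cKV_sel = torch.gather(kv_cache[:, :t_end], 1,
                       idx.unsqueeze(-1).expand(-1, -1, d_C))          # (B, k, d_C)
kR_sel  = torch.gather(pe_cache[:, :t_end], 1,
                       idx.unsqueeze(-1).expand(-1, -1, d_RoPE))       # (B, k, d_RoPE)

# Core MLA now scores and aggregates over k tokens instead of all t
scores = (
    torch.einsum("bshc,bkc->bshk", q_lat, cKV_sel) +
    torch.einsum("bshr,bkr->bshk", qR,    kR_sel)
) * softmax_scale
attn   = scores.softmax(dim=-1)                                        # (B, 1, H, k)
x_lat  = torch.einsum("bshk,bkc->bshc", attn, cKV_sel)                 # (B, 1, H, d_C)
x_head = torch.einsum("bshc,hdc->bshd", x_lat, wkv_b[:, -d_V:])        # (B, 1, H, d_V)
x_out  = wo(x_head.flatten(2))                                         # (B, 1, d_model)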