DeepSeek Sparse Attention

Original paper · DeepSeek-AI et al. 2025

DeepSeek Sparse Attention (DSA) is a mechanism that augments a standard transformer with a lightweight, parallel "lightning indexer." This indexer rapidly scores the entire token history and selects a small top-$k$ set of the most relevant past tokens for each new query. The main "core" attention mechanism, in this case Multi-Head Latent Attention (MLA), then operates only on this sparsely selected set. This approach dramatically cuts down on the compute required for long-context processing.

In DeepSeek's V3.2-Exp model, DSA is instantiated under MLA in an MQA configuration, meaning each compressed latent key-value entry is shared across all query heads. That said, DSA's lightning indexer is compatible with any core attention variant, not just MLA.

Why DSA?

MLA, which we'll recap shortly, was a significant step forward. It introduced a compressed KV cache, allowing models to leverage full MHA during training while running efficient MQA at inference. This design delivers massive KV-cache savings and preserves high model quality. However, the computation under MLA remains effectively equivalent to standard attention. The attention score calculation is still quadratic, $O(L^2)$, in sequence length.

As context lengths explode—especially in complex, agentic settings—the FLOPs required by this quadratic attention term become the dominant bottleneck. While 2025 has seen a wave of linear and sparse attention variants (DeepSeek previously explored NSA, but stable training proved elusive), DSA represents a pragmatic and stable intermediate. It bridges the gap by keeping the indexer extremely cheap (few heads, small head dimension, FP8 precision) and narrowing the expensive main attention mechanism to just the $k$ surviving tokens. This yields an overall $O(L \cdot k)$ complexity for the core attention term, where $k \ll L$.
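To put the asymptotics in concrete terms, here is a back-of-the-envelope check using illustrative numbers (a 128K-token context and the $k = 2048$ used later in this post); the ratio, not the exact figures, is the point.

# Rough fraction of the history the core attention touches per query under DSA,
# using illustrative numbers (L = 128K context, k = 2048 selected tokens):
L, k = 128 * 1024, 2048
print(f"{k}/{L} = {k / L:.2%} of the quadratic term's work remains")  # 1.56%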

DSA decoding cost

We'll anchor the rest of this post on the "DSA under MLA" architecture. The diagram below shows the two key paths:

  1. The Green Path: The new Lightning Indexer. It scans the history to produce index scores that feed a top-$k$ selector.
  2. The Gray Path: The original Core Attention (MLA), which now consumes only the $k$ selected latent KV entries.

View this as a fast, compact search stage (the indexer) that gates a powerful but expensive attention stage (MLA).
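As a purely structural sketch (random tensors stand in for the real indexer and caches; shapes are the only thing being illustrated), the gate looks like this at decode time:

import torch

# Hypothetical shapes: batch 1, history of 8192 cached tokens, k = 2048.
B, t, k = 1, 8192, 2048
index_score = torch.randn(B, 1, t)                      # green path: one score per past token
topk_idx = index_score.topk(min(k, t), dim=-1).indices  # top-k selector, (B, 1, k)
# gray path: core MLA attention would now gather and attend over only these
# k cache entries instead of all t (walked through in detail below)
print(topk_idx.shape)                                   # torch.Size([1, 1, 2048])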

DSA under MLA

To make sense of the integration, we first need to understand the baseline. We'll recap just enough of MLA to make the DSA modifications obvious, then walk through the indexer's design, connecting the math to the modeling code.

MLA

Let's first establish the MLA baseline. We enter the attention block with hidden states $x \in \mathbb{R}^{B \times S \times d_{\text{model}}}$. MLA's core idea is to factorize queries and keys into two complementary channels: a latent channel (which carries the content, or "what") and a RoPE channel (which carries the position, or "where"). Values live entirely within the latent channel.
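To keep the shapes straight in the snippets below, here is a small reference of the dimension names we'll use, filled in with illustrative DeepSeek-V3-style values; the exact numbers are not important for the mechanics.

from dataclasses import dataclass

@dataclass
class MLAShapes:
    # Illustrative values only; what matters is the mapping from config names to symbols.
    d_model: int = 7168          # model hidden size
    n_heads: int = 128           # H, number of query heads
    q_lora_rank: int = 1536      # d_Q, compressed query rank
    kv_lora_rank: int = 512      # d_C, compressed latent KV width
    qk_nope_head_dim: int = 128  # d_NoPE, no-RoPE query/key head dim
    qk_rope_head_dim: int = 64   # d_RoPE, decoupled RoPE head dim
    v_head_dim: int = 128        # d_V, value head dim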

The KV Path

First, we produce the key-value pre-activations via the "A" projection, $W^{KV}_{A}$. This tensor is immediately split into two parts: the latent KV ($c^{KV}$) and the RoPE key ($k^R$).

$$\tilde{k}^{A} = x W^{KV}_{A} \in \mathbb{R}^{B \times S \times (d_C + d_{\mathrm{RoPE}})}$$
$$\tilde{k}^{A} \Rightarrow c^{KV} \in \mathbb{R}^{B \times S \times d_C} \ \oplus\ k^{R} \in \mathbb{R}^{B \times S \times d_{\mathrm{RoPE}}}$$

Here, $d_C = \texttt{kv\_lora\_rank}$ is the dimension of the compressed latent KV, and $d_{\mathrm{RoPE}} = \texttt{qk\_rope\_head\_dim}$ is the head dimension of the decoupled RoPE key. The latent path is normalized, while the RoPE path gets its positional information.

$$c^{KV} \leftarrow \mathrm{RMSNorm}(c^{KV}) \qquad k^{R} \leftarrow \mathrm{RoPE}(k^{R})$$
# x: (B, S, d_model)
kvA = wkv_a(x)               # (B, S, d_C + d_RoPE)
cKV, kR = torch.split(kvA, [d_C, d_RoPE], dim=-1)  # (B,S,d_C), (B,S,d_RoPE)

# Apply norm and RoPE
cKV = kv_norm(cKV)
kR = apply_rotary_emb(kR, freqs_cis)   # (B,S,d_RoPE)

These two tensors are our entire key-value cache. For each token, we store just the compact content latent $c^{KV}$ and the decoupled positional key $k^{R}$.

kv_cache[:, start_pos:end_pos, :] = cKV  # (B, S, d_C)
pe_cache[:, start_pos:end_pos, :] = kR   # (B, S, d_RoPE)
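For a rough sense of why this cache is so small, compare the per-token entry against what a hypothetical full-MHA cache with the same head geometry would store, using the illustrative dimensions above.

# Per-token, per-layer cache entries (element counts, not bytes), with the
# illustrative dims above (d_C = 512, d_RoPE = 64, H = 128, d_NoPE = 128, d_V = 128):
mla_entry = 512 + 64                   # cKV + kR                        = 576
mha_entry = 128 * ((128 + 64) + 128)   # per-head K + V under full MHA   = 40960
print(mha_entry / mla_entry)           # ≈ 71x fewer cached elements per token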

The Query Path

On the query side, we do something similar. We first apply a low-rank "A" projection ($W^{Q}_{A}$) and normalize to get a compressed latent query, $c^Q$. (This $c^Q$ will be reused later by the lightning indexer, a key efficiency.)

$$c^Q = \mathrm{RMSNorm}(x W^{Q}_{A}) \in \mathbb{R}^{B \times S \times d_Q}$$
# x: (B, S, d_model)
cQ  = q_norm(wq_a(x))        # (B, S, d_Q)

Here, $d_Q = \texttt{q\_lora\_rank}$ is the compressed query rank. To form the final per-head queries, we apply a second "B" projection ($W^{Q}_{B}$) to this latent $c^Q$, which expands it into $H$ heads. We then split each head into its no-RoPE ($q^A$) and RoPE ($q^R$) subspaces.

$$q = c^Q W^{Q}_{B} \;\Rightarrow\; \big(q^{A} \in \mathbb{R}^{B \times S \times H \times d_{\text{NoPE}}},\ q^{R} \in \mathbb{R}^{B \times S \times H \times d_{\text{RoPE}}}\big), \qquad q^{R} \leftarrow \mathrm{RoPE}(q^{R})$$
# project to heads, then split
q_full = wq_b(cQ).view(B, S, H, d_NoPE + d_RoPE)
qA, qR = torch.split(q_full, [d_NoPE, d_RoPE], dim=-1)  # (B,S,H,d_NoPE), (B,S,H,d_RoPE)

# Apply RoPE only to the qR slice
qR = apply_rotary_emb(qR, freqs_cis)

Here, $d_{\text{NoPE}} = \texttt{qk\_nope\_head\_dim}$ and $d_{\text{RoPE}} = \texttt{qk\_rope\_head\_dim}$ sum to the full query-key head dimension $\texttt{qk\_head\_dim}$.

MLA "Trick"

Now we have our queries $\langle q^{A}, q^{R} \rangle$ and our cached keys $\langle c^{KV}, k^{R} \rangle$. A critical component of MLA is that we never materialize the full per-head keys for the entire history. Instead, we use an algebraic trick to score directly against the compact caches.

Consider the no-RoPE score contribution. If we did materialize the key for head $h$ at time $t$, we would take the latent $c^{KV}_t \in \mathbb{R}^{d_C}$ and apply the key block of the "B" up-projection, let's call it $W_K \in \mathbb{R}^{d_{\text{NoPE}} \times d_C}$.

$$k^{\text{NoPE}}_{t} = W_K c^{KV}_t \in \mathbb{R}^{d_{\text{NoPE}}}.$$

The score would be the dot product with the query $q^A$:

$$\langle q^{A}, k^{\text{NoPE}}_{t} \rangle = q^{A\top} (W_K c^{KV}_t) = (W_K^\top q^{A})^\top c^{KV}_t.$$

This identity is the trick! Instead of projecting all $t$ past keys up ($W_K c^{KV}_t$), we project the single query $q^A$ down ($W_K^\top q^{A}$) and compute the dot product in the compact latent space of dimension $d_C$. We transform the query once so it can score directly against the $c^{KV}$ cache.

In code, this query transformation $\tilde{q} = W_K^\top q^{A}$ (named q_lat) looks a bit odd, but it's just a reshape and an einsum:

# wkv_b_weight is (H * (d_NoPE + d_V), d_C)
# We view it as (H, d_NoPE + d_V, d_C)
wkv_b = wkv_b_weight.view(H, d_NoPE + d_V, d_C)

# qA: (B, 1, H, d_NoPE)
# Take the first d_NoPE rows (the key block) for each head
q_lat = torch.einsum("bshd,hdc->bshc", qA, wkv_b[:, :d_NoPE])  # (B, 1, H, d_C)
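As a quick numerical sanity check of the identity, here is a standalone snippet with random tensors (float64 for clean comparison; this is a demonstration, not model code):

import torch

# Verify the absorption identity q^T (W_K c) == (W_K^T q)^T c numerically.
d_C, d_NoPE = 512, 128
W_K = torch.randn(d_NoPE, d_C, dtype=torch.float64)
qA  = torch.randn(d_NoPE, dtype=torch.float64)
cKV = torch.randn(d_C, dtype=torch.float64)

score_materialized = qA @ (W_K @ cKV)    # up-project the key, then dot with the query
score_absorbed     = (W_K.T @ qA) @ cKV  # down-project the query, then dot with the latent
assert torch.allclose(score_materialized, score_absorbed)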

The final scores are the sum of two dot products, both operating on the raw caches:

  1. Latent Score: The transformed query q_lat dotted with the $c^{KV}$ cache.
  2. RoPE Score: The RoPE query $q^R$ dotted with the $k^R$ cache.
# q_lat: (B, 1, H, d_C),  kv_cache: (B, t, d_C)
# qR:    (B, 1, H, d_RoPE), pe_cache: (B, t, d_RoPE)
scores = (
    torch.einsum("bshc,btc->bsht", q_lat, kv_cache[:, :t_end]) +  # latent: \tilde{q}^T c^{KV}_t
    torch.einsum("bshr,btr->bsht", qR,    pe_cache[:, :t_end])    # RoPE:  q^R · k^R_t
) * softmax_scale

Following the softmax, we aggregate the values in latent space using the same $c^{KV}$ cache. This keeps the heaviest computation (the weighted sum) in the compact $d_C$ dimension:

attn  = scores.softmax(dim=-1)                                   # (B, 1, H, t)
x_lat = torch.einsum("bsht,btc->bshc", attn, kv_cache[:, :t_end])  # (B, 1, H, d_C)

Finally, we expand this latent representation $x_{\text{lat}}$ to the full value head dimension $d_V$ using the value block of the $W^{KV}_B$ projection (the last $d_V$ rows) and project back to model space.

# Use the value rows (last d_V) per head to up-project
x_head = torch.einsum("bshc,hdc->bshd", x_lat, wkv_b[:, -d_V:])  # (B, 1, H, d_V)
x_out  = wo(x_head.flatten(2))                                   # (B, 1, d_model)

This latent trick produces exactly the same result as standard attention but avoids ever constructing full per-head keys, saving memory and compute during the decode loop.

DeepSeek Sparse Attention (DSA)

MLA gave us compact per-token caches ($c^{KV}$ and $k^R$) and a decode path that avoids materializing full keys. What it didn't change is how many past tokens we touch: all $t$ of them, making the scoring $O(t)$.

The Lightning Indexer adds a fast, low-dimensional search stage before this scoring. It scans the entire history in a tiny, FP8-quantized space and proposes a top-$k$ candidate set. The expensive MLA scoring logic we just reviewed is then only run on those $k$ tokens.

Building the Indexer Space

The indexer builds its own compact space with $H_I = \texttt{index\_n\_heads}$ (e.g., 64) small heads of width $d_I = \texttt{index\_head\_dim}$ (e.g., 128). This space is separate from the main attention.

Indexer Query: The indexer query path starts from the same compressed query activation $c^Q \in \mathbb{R}^{B \times S \times d_Q}$ that MLA used. This is a key efficiency. We project it to the indexer's head space, split into NoPE and RoPE slices, and apply RoPE.

$$q^{\text{idx}} = c^Q W^{Q,\text{idx}}_{B} \in \mathbb{R}^{B \times S \times H_I \times (d_{I,\text{NoPE}} + d_{I,\text{RoPE}})}$$
$$q^{\text{mix}} = \mathrm{concat}\big(q^{\text{idx}}_{\text{NoPE}},\ \mathrm{RoPE}(q^{\text{idx}}_{\text{RoPE}})\big) \in \mathbb{R}^{B \times S \times H_I \times d_I}$$
q_idx = wq_b_index(cQ).view(B, S, H_I, dI_NoPE + dI_RoPE)   # (B,S,H_I,·)
qI_nope, qI_rope = torch.split(q_idx, [dI_NoPE, dI_RoPE], dim=-1)
qI_rope = apply_rotary_emb(qI_rope, freqs_cis)
q_mix = torch.cat([qI_nope, qI_rope], dim=-1)                # (B,S,H_I,d_I)

Indexer Key: The same process is repeated for the keys, starting from the model hidden state $x$. However, the indexer key path is MQA-style (it is not split into heads), producing a single key vector per token.

$$k^{\text{mix}} \in \mathbb{R}^{B \times S \times d_I}$$

Hadamard Rotation: To decorrelate features and improve the numerical properties for low-precision math, both the queries and keys are rotated using a Walsh-Hadamard transform. Think of this as conditioning the vectors for a robust FP8 search.

$$q^{\text{rot}} = \mathrm{Hadamard}\big(q^{\text{mix}}\big), \qquad k^{\text{rot}} = \mathrm{Hadamard}\big(k^{\text{mix}}\big)$$
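For intuition, here is a minimal reference for what the rotation does, written as the standard fast Walsh-Hadamard recursion over the last dimension (the production version is a fused kernel; the orthonormal scaling convention here is an assumption):

import torch

def hadamard_transform(x: torch.Tensor) -> torch.Tensor:
    # Fast Walsh-Hadamard transform over the last dim; requires a power-of-two
    # width (d_I = 128 qualifies). Uses 1/sqrt(d) orthonormal scaling.
    d = x.shape[-1]
    assert d & (d - 1) == 0, "last dim must be a power of two"
    y, h = x, 1
    while h < d:
        y = y.reshape(*x.shape[:-1], d // (2 * h), 2, h)
        a, b = y[..., 0, :], y[..., 1, :]
        y = torch.stack([a + b, a - b], dim=-2)   # butterfly: sums then differences
        h *= 2
    return y.reshape(x.shape) * d ** -0.5

# q_rot = hadamard_transform(q_mix); k_rot = hadamard_transform(k_mix)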

Scoring and Selection

The indexer runs entirely in FP8. Both the query and key activations are quantized, and the keys are cached in FP8 along with their per-block quantization scales.

$$(q_{\mathrm{fp8}}, s_q) = \mathrm{quant8}(q^{\text{rot}}), \qquad (k_{\mathrm{fp8}}, s_k) = \mathrm{quant8}(k^{\text{rot}})$$
q_fp8, q_scale = act_quant(q_rot)         # (B,S,H_I,d_I), (B,S,H_I,1 or blocks)
k_fp8, k_scale = act_quant(k_rot)         # (B,S,d_I),     (B,S,blocks)

# Update the FP8 indexer cache
k_cache[:, start:end]       = k_fp8
k_scale_cache[:, start:end] = k_scale
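To make the quantization step concrete, here is a minimal reference sketch of what an activation quantizer like act_quant computes, assuming 128-wide blocks and the e4m3 FP8 format; the real implementation is a fused kernel whose exact clamping and scaling conventions may differ.

import torch

def act_quant_ref(x: torch.Tensor, block: int = 128):
    # Per-block absmax scaling into float8_e4m3fn (max representable ~448).
    # Returns the FP8 tensor plus the per-block scales needed to dequantize later.
    *lead, d = x.shape
    assert d % block == 0
    xb = x.reshape(*lead, d // block, block)
    scale = xb.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / 448.0
    x_fp8 = (xb / scale).to(torch.float8_e4m3fn).reshape(x.shape)
    return x_fp8, scale.squeeze(-1)   # scales: (..., d // block)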

At decode time ($S = 1$), a specialized fused kernel performs the search. This kernel runs an FP8 GEMM between the query $q_{\mathrm{fp8}}$ and the entire key cache $k_{\mathrm{fp8}, \le t}$. This is still an $O(t)$ operation, but the constants are tiny ($H_I$ and $d_I$ are small, and the math is FP8).

The kernel computes a nonnegative similarity (using a ReLU) for each head, then performs a weighted sum across the $H_I$ heads to produce a single scalar "index score" for each past token:

$$\text{logits}_{h,t} = q_{\mathrm{fp8},h} \cdot k_{\mathrm{fp8},t}$$
$$\text{logits}^{+}_{h,t} = \max\big(0,\ \text{logits}_{h,t}\big)$$
$$\text{score}_{t} = \sum_{h=1}^{H_I} w_h \, \text{logits}^{+}_{h,t}$$
$$\text{index\_score}_{t} = \text{score}_{t} \cdot s_k(t)$$

Where:

  • $q_{\mathrm{fp8},h}$ is the FP8 query for indexer head $h$.
  • $k_{\mathrm{fp8},t}$ is the FP8 key for token $t$.
  • $\max(0,\cdot)$ is a ReLU that discards negative correlations.
  • $w_h$ is a learned scalar gate (derived from the current $x$) that weights the importance of each indexer head.
  • $s_k(t)$ is the per-block dequantization scale for key $t$, which restores the magnitude after the FP8 dot product.
# These weights are derived from x to combine heads cheaply inside the kernel
q_weights = head_weight_proj(x).view(B, S, H_I) * H_I ** -0.5   # scale by 1/sqrt(H_I)

# The fused fp8_index kernel does all the steps above:
# (B,S,H_I,d_I) @ (B,t,d_I) -> (B,S,t)
# It applies the ReLU, head weighting (q_weights), and k_scales internally.
index_score = fp8_index(
    q_fp8,                     # (B,S,H_I,d_I)
    q_weights,                 # (B,S,H_I)
    k_cache[:, :t_end],        # (B,t,d_I)    FP8
    k_scale_cache[:, :t_end],  # (B,t,blocks) per-block scales
)  # → (B,S,t)
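For reference, this is roughly what the fused kernel computes, written out in plain full-precision PyTorch (with a single per-key scale in place of per-block scales, purely for readability; the real kernel stays in FP8 and fuses every step):

import torch

def index_score_ref(q, weights, k, k_scale):
    # q: (B, S, H_I, d_I) dequantized queries, weights: (B, S, H_I) head gates,
    # k: (B, t, d_I) dequantized keys, k_scale: (B, t) one scale per key (simplified).
    logits = torch.einsum("bshd,btd->bsht", q.float(), k.float())  # per-head similarity
    logits = torch.relu(logits)                                    # drop negative correlations
    score  = torch.einsum("bsht,bsh->bst", logits, weights)        # weighted sum over heads
    return score * k_scale[:, None, :]                             # restore key magnitude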

Finally, we take a simple top-$k$ on these scores to find the indices of the $k$ (e.g., 2048) most relevant tokens.
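Putting selection and gating together, a decode-time sketch ($S = 1$, simplified index handling, variable names following the earlier snippets) might look like:

# Gather only the k selected cache entries, then run the same MLA scoring as
# before over k tokens instead of all t:
topk_idx = index_score.topk(min(k, t_end), dim=-1).indices.squeeze(1)   # (B, k)

cKV_sel = torch.gather(kv_cache[:, :t_end], 1,
                       topk_idx.unsqueeze(-1).expand(-1, -1, d_C))      # (B, k, d_C)
kR_sel  = torch.gather(pe_cache[:, :t_end], 1,
                       topk_idx.unsqueeze(-1).expand(-1, -1, d_RoPE))   # (B, k, d_RoPE)

scores = (
    torch.einsum("bshc,bkc->bshk", q_lat, cKV_sel) +   # latent score over k tokens
    torch.einsum("bshr,bkr->bshk", qR,    kR_sel)      # RoPE score over k tokens
) * softmax_scale
attn  = scores.softmax(dim=-1)
x_lat = torch.einsum("bshk,bkc->bshc", attn, cKV_sel)  # aggregate values in latent space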

This is the complete DSA loop:

  1. Build compact FP8 search vectors (Q and K) for the indexer.
  2. Run a fast, fused FP8 search to score all $t$ past tokens.
  3. Select the top-$k$ token indices.
  4. Pass only these $k$ indices to the main MLA attention layer, which performs its expensive $O(k)$ scoring and aggregation.

By filtering the history, DSA lets MLA operate on a tiny fraction of the full context, achieving massive computational savings while maintaining high performance.