How do you implement self-attention and multi-head attention from scratch in an interview?
Updated June 18, 2026 · 8 min read · Crack ML Interview
Implementing attention from scratch is one of the most common ML coding questions at OpenAI, Anthropic, and Meta. Win it by narrating shapes at every step: project the input into Q, K, V; compute scaled dot-product scores as QK^T divided by the square root of the head dimension; apply the causal or padding mask before softmax; weight V by the attention probabilities; then merge the heads back and apply the output projection. Always state the O(n² d) time and O(n²) memory complexity (n is sequence length, d the model dimension), handle masking edge cases such as fully masked rows explicitly, and write a numerically stable softmax. Candidates who annotate tensor dimensions and explain why the scaling factor exists consistently outscore those who only produce running code.
Scaled Dot-Product Attention: The Core You Must Get Right First
Build it bottom-up and narrate every tensor shape
Start from a single-head version so the interviewer sees you understand the mechanics before adding complexity. Given queries Q and keys K of shape (batch, seq_len, d_k) and values V of shape (batch, seq_len, d_v), compute the scores as Q multiplied by K transposed over its last two dimensions, producing a (batch, seq_len, seq_len) matrix. Divide the scores by the square root of d_k. Apply softmax over the last dimension to get attention probabilities that sum to one across keys. Finally, multiply the probabilities by V to produce a (batch, seq_len, d_v) output. State the shape after every line out loud; shape narration is the single strongest signal that you have actually built this before.
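A minimal PyTorch sketch of the unmasked single-head version, with the shape after each line written as a comment. The function name and the choice to also return the probabilities (handy for inspecting them) are illustrative, not a library API.

```python
import math
import torch

def scaled_dot_product_attention(q, k, v):
    """Single-head attention, no mask.

    q, k: (batch, seq_len, d_k); v: (batch, seq_len, d_v)
    """
    d_k = q.size(-1)
    # (B, L, d_k) @ (B, d_k, L) -> (B, L, L): one score per (query, key) pair
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)
    # softmax over the last (key) axis so each query's weights sum to 1
    probs = torch.softmax(scores, dim=-1)  # (B, L, L)
    # (B, L, L) @ (B, L, d_v) -> (B, L, d_v)
    return probs @ v, probs

torch.manual_seed(0)
q, k = torch.randn(2, 5, 8), torch.randn(2, 5, 8)
v = torch.randn(2, 5, 16)
out, probs = scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([2, 5, 16])
```

A quick sanity check during the interview: if every query is zero, all scores are equal, the softmax is uniform, and each output row is simply the mean of the value rows.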
Explain why the scaling factor and the softmax dimension matter
The division by the square root of d_k keeps the dot products from growing in magnitude as the dimension increases, which would otherwise push softmax into saturated regions where gradients vanish. If asked why, explain that when query and key components are independent with zero mean and unit variance, each dot product is a sum of d_k unit-variance terms, so its variance equals d_k; dividing by the square root of d_k brings it back to unit variance. The softmax must be applied over the key dimension, not the query dimension, because each query distributes its attention across all keys. Getting the softmax axis wrong is the most common silent bug interviewers watch for.
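The variance argument is easy to check empirically. This sketch, assuming standard-normal query and key components, estimates the variance of raw and scaled dot products for a few values of d_k; the helper name `dot_product_variance` is hypothetical.

```python
import math
import torch

def dot_product_variance(d_k, n_samples=20_000, scale=False, seed=0):
    """Empirical variance of q . k for random q, k with i.i.d. N(0, 1) components."""
    g = torch.Generator().manual_seed(seed)
    q = torch.randn(n_samples, d_k, generator=g)
    k = torch.randn(n_samples, d_k, generator=g)
    dots = (q * k).sum(dim=-1)  # one dot product per sample
    if scale:
        dots = dots / math.sqrt(d_k)
    return dots.var().item()

for d_k in (16, 64, 256):
    raw = dot_product_variance(d_k)
    scaled = dot_product_variance(d_k, scale=True)
    print(f"d_k={d_k}: raw variance ~ {raw:.1f}, scaled variance ~ {scaled:.2f}")
```

The raw variance should track d_k while the scaled variance stays near 1, which is exactly the property the square-root factor is there to provide.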
Handle masking before softmax, not after
Causal masking for decoders and padding masking for variable-length batches must be applied to the raw scores before softmax, by setting masked positions to negative infinity so they receive zero probability after the exponential. A frequent mistake is zeroing out probabilities after softmax, which leaves the remaining weights unnormalized. Build the causal mask as a boolean upper-triangular matrix that is True strictly above the diagonal, then masked-fill the scores at those positions. Also mention the edge cases: a query row whose keys are all masked (for example, a fully padded position) turns into NaN after softmax when every score is literal negative infinity, so implementations often fill with the most negative finite value for the dtype instead; a hard-coded constant such as -1e9 is outside the fp16 range (maximum magnitude about 65504), so it is not a safe choice for half-precision code.
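A sketch of masking applied before softmax, covering both a causal mask and a key-padding mask. It assumes a boolean padding mask of shape (batch, seq_len) where True marks a padded key; that convention, the function name, and the `causal` flag are illustrative choices, not taken from the article.

```python
import math
import torch

def masked_attention(q, k, v, causal=False, key_padding_mask=None):
    """q, k: (B, L, d_k); v: (B, L, d_v).
    key_padding_mask: optional bool (B, L), True marks padded keys to ignore.
    """
    L = q.size(-2)
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))  # (B, L, L)
    # most negative finite value for this dtype, so an all-masked row stays finite
    neg = torch.finfo(scores.dtype).min
    if causal:
        # True strictly above the diagonal: future keys a query must not see
        future = torch.triu(torch.ones(L, L, dtype=torch.bool, device=q.device), diagonal=1)
        scores = scores.masked_fill(future, neg)  # broadcasts over the batch
    if key_padding_mask is not None:
        # (B, L) -> (B, 1, L) so every query row ignores the same padded keys
        scores = scores.masked_fill(key_padding_mask[:, None, :], neg)
    probs = torch.softmax(scores, dim=-1)  # masked entries underflow to 0
    return probs @ v, probs

torch.manual_seed(0)
q, k, v = torch.randn(2, 4, 8), torch.randn(2, 4, 8), torch.randn(2, 4, 3)
out, probs = masked_attention(q, k, v, causal=True)
```

With this fill value, a query whose keys are all masked gets uniform weights rather than NaN; callers usually discard or zero those rows afterwards.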
Extending to Multi-Head Attention
Project, split into heads, attend, then merge
Multi-head attention applies three linear projections to the input to produce Q, K, V of shape (batch, seq_len, d_model), then reshapes each to (batch, seq_len, num_heads, d_head) and transposes it to (batch, num_heads, seq_len, d_head), where d_head equals d_model divided by num_heads. Scaled dot-product attention then runs independently per head in parallel, because batched matmul treats the head dimension as just another batch dimension. After attention, transpose back to (batch, seq_len, num_heads, d_head), reshape to (batch, seq_len, d_model), and apply a final output linear projection. The key insight to verbalize is that multiple heads let the model attend to different representation subspaces simultaneously at no extra asymptotic cost over a single head of width d_model.
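One way to write the full module in PyTorch, following the project, split, attend, merge sequence above. Attribute names such as `w_q` and `w_o`, the use of biased `nn.Linear` layers, and the `causal` flag are assumptions for illustration; a real model might also add dropout or a padding mask.

```python
import math
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must divide evenly across heads"
        self.h = num_heads
        self.d_head = d_model // num_heads
        # separate projections for Q, K, V, plus the output projection
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

    def _split(self, x):
        B, L, _ = x.shape
        # (B, L, d_model) -> (B, L, H, d_head) -> (B, H, L, d_head)
        return x.view(B, L, self.h, self.d_head).transpose(1, 2)

    def forward(self, x, causal=False):
        B, L, d_model = x.shape
        q = self._split(self.w_q(x))
        k = self._split(self.w_k(x))
        v = self._split(self.w_v(x))
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_head)  # (B, H, L, L)
        if causal:
            future = torch.triu(torch.ones(L, L, dtype=torch.bool, device=x.device), diagonal=1)
            scores = scores.masked_fill(future, torch.finfo(scores.dtype).min)
        probs = torch.softmax(scores, dim=-1)  # (B, H, L, L)
        out = probs @ v  # (B, H, L, d_head)
        # transpose BEFORE reshaping: (B, H, L, d_head) -> (B, L, H, d_head) -> (B, L, d_model)
        out = out.transpose(1, 2).reshape(B, L, d_model)
        return self.w_o(out)

torch.manual_seed(0)
mha = MultiHeadAttention(d_model=32, num_heads=4)
x = torch.randn(2, 6, 32)
y = mha(x, causal=True)
print(y.shape)  # torch.Size([2, 6, 32])
```

A useful self-test for the causal path: perturbing the last token should leave every earlier output unchanged.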
State complexity and the practical bottleneck
Time complexity is O(seq_len² × d_model) for the attention itself: the score matrix has seq_len × seq_len entries per head, each costing a d_head dot product, and num_heads × d_head equals d_model. The Q, K, V, and output projections add O(seq_len × d_model²), which dominates only when seq_len is smaller than d_model. Memory complexity is O(num_heads × seq_len²) per sequence for storing the attention matrix, which is the dominant cost for long contexts and the reason FlashAttention exists: it avoids materializing the full score matrix by computing attention in tiles with an online softmax. Mentioning FlashAttention and the quadratic memory wall shows you understand why this primitive is a production bottleneck, not just an academic exercise.
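The online-softmax idea behind FlashAttention can be shown in a few lines. This is a simplified sketch, not the FlashAttention kernel: it processes keys in tiles and keeps a running max, running denominator, and running weighted sum per query, so the full (Lq, Lk) score matrix is never built. The real kernel also tiles queries and fuses everything into one GPU kernel using on-chip memory; the function name and `tile` parameter here are hypothetical.

```python
import math
import torch

def tiled_attention(q, k, v, tile=4):
    """q: (Lq, d), k: (Lk, d), v: (Lk, d_v). Keys are consumed one tile at a time."""
    scale = 1.0 / math.sqrt(q.size(-1))
    n = q.size(0)
    m = torch.full((n, 1), float("-inf"), dtype=q.dtype)  # running max score per query
    denom = torch.zeros(n, 1, dtype=q.dtype)  # running softmax denominator
    acc = torch.zeros(n, v.size(-1), dtype=q.dtype)  # running unnormalized output
    for start in range(0, k.size(0), tile):
        s = (q @ k[start:start + tile].T) * scale  # (Lq, tile) scores for this tile only
        m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
        correction = torch.exp(m - m_new)  # rescales old statistics to the new max
        p = torch.exp(s - m_new)  # (Lq, tile)
        denom = denom * correction + p.sum(dim=-1, keepdim=True)
        acc = acc * correction + p @ v[start:start + tile]
        m = m_new
    return acc / denom

torch.manual_seed(0)
q, k, v = torch.randn(5, 8), torch.randn(11, 8), torch.randn(11, 3)
reference = torch.softmax(q @ k.T / math.sqrt(8), dim=-1) @ v
print(torch.allclose(tiled_attention(q, k, v, tile=4), reference, atol=1e-5))  # True
```

The result matches ordinary attention for any tile size, including one that does not divide the key length, because the correction factor re-expresses earlier partial sums relative to the latest running max.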
Common Follow-Ups and How to Answer Them
Grouped-query and multi-query attention
A frequent senior-level follow-up asks how to reduce KV cache memory during inference. The answer is multi-query attention, where all query heads share a single key and value head, and grouped-query attention, where query heads are partitioned into groups that each share one KV head. Both shrink the KV cache proportionally to the reduction in KV heads, trading a small quality cost for large memory and bandwidth savings during autoregressive decoding. Modern open models like Llama 3 use grouped-query attention specifically for this reason.
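A sketch of grouped-query attention that makes the sharing explicit: each KV head is repeated across its group of query heads before the usual attention math. One KV head gives multi-query attention and as many KV heads as query heads gives standard multi-head attention. The `kv_cache_bytes` helper and the layer, head, and length numbers in the usage line are illustrative assumptions, not the configuration of any specific model.

```python
import math
import torch

def grouped_query_attention(q, k, v):
    """q: (B, H_q, L, d_head); k, v: (B, H_kv, L, d_head), with H_q divisible by H_kv."""
    h_q, h_kv = q.size(1), k.size(1)
    assert h_q % h_kv == 0, "query heads must split evenly into KV groups"
    group = h_q // h_kv
    # each KV head serves `group` consecutive query heads: (B, H_kv, L, d) -> (B, H_q, L, d)
    k = k.repeat_interleave(group, dim=1)
    v = v.repeat_interleave(group, dim=1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))  # (B, H_q, L, L)
    return torch.softmax(scores, dim=-1) @ v  # (B, H_q, L, d_head)

def kv_cache_bytes(num_layers, h_kv, d_head, seq_len, batch=1, bytes_per_elem=2):
    # K and V per layer, each of shape (batch, h_kv, seq_len, d_head)
    return 2 * num_layers * batch * h_kv * seq_len * d_head * bytes_per_elem

# Illustrative config: 32 query heads sharing 8 KV heads cuts the cache by 4x.
print(kv_cache_bytes(32, 32, 128, 8192) // kv_cache_bytes(32, 8, 128, 8192))  # 4
```

Production implementations typically avoid materializing the repeated K and V, but the explicit repeat keeps the grouping easy to explain at a whiteboard.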
Why positional information must be added separately
Self-attention without a causal mask is permutation-equivariant: shuffling the input tokens shuffles the output rows in exactly the same way, because attention has no inherent notion of order. This is why transformers add positional information, either learned absolute embeddings, sinusoidal encodings, or rotary position embeddings applied to Q and K. If the interviewer asks why attention needs position information when RNNs do not, the answer is that RNNs encode order through sequential processing, while attention processes all positions at once and must be told their order explicitly.
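A short demonstration of the equivariance property, assuming plain self-attention with random projection matrices, no positional encoding, and no mask; `self_attention` and the weight names are illustrative.

```python
import math
import torch

def self_attention(x, w_q, w_k, w_v):
    # x: (L, d_model); no positional encoding, no mask
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    probs = torch.softmax(q @ k.T / math.sqrt(q.size(-1)), dim=-1)
    return probs @ v

torch.manual_seed(0)
L, d = 6, 8
x = torch.randn(L, d)
w_q, w_k, w_v = (torch.randn(d, d) / math.sqrt(d) for _ in range(3))
perm = torch.randperm(L)
out = self_attention(x, w_q, w_k, w_v)
out_perm = self_attention(x[perm], w_q, w_k, w_v)
# permuting the inputs just permutes the outputs the same way
print(torch.allclose(out_perm, out[perm], atol=1e-5))  # True
```

Nothing in the computation depends on where a token sits, only on its content, which is why position must be injected separately. Adding a causal mask breaks this symmetry, since the mask itself depends on position.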
Attention Implementation Checklist: Steps, Shapes, and Pitfalls
| Step | Operation | Output Shape | Common Mistake |
|---|---|---|---|
| 1. Project | Linear layers for Q, K, V | (B, L, d_model) | Forgetting separate projections per Q/K/V |
| 2. Split heads | Reshape to (B, H, L, d_head) | (B, H, L, d_head) | Wrong head/dim split order |
| 3. Scores | Q @ K^T / sqrt(d_head) | (B, H, L, L) | Omitting the scaling factor |
| 4. Mask | Set masked scores to -inf | (B, H, L, L) | Masking after softmax instead of before |
| 5. Softmax | Softmax over last (key) dim | (B, H, L, L) | Applying softmax over the query axis |
| 6. Weight values | Probs @ V | (B, H, L, d_head) | Matmul with K instead of V |
| 7. Merge heads | Transpose + reshape to (B, L, d_model) | (B, L, d_model) | Reshape without transposing first |
| 8. Output proj | Final linear layer | (B, L, d_model) | Skipping the output projection |
Who this is for
PyTorch-comfortable engineer who has only used nn.MultiheadAttention
Profile: Builds and fine-tunes transformer models using high-level library modules, but has never written the attention computation by hand and is fuzzy on the internal tensor shapes.
Pain points: Freezes when asked to implement attention without library helpers, mixes up the head-splitting reshape, and cannot explain why scores are divided by the square root of the head dimension.
Strategy: Hand-write scaled dot-product attention and multi-head attention five times from a blank file, annotating every shape in a comment. Use a runnable environment like Crack ML Interview LeanCode to verify outputs against nn.MultiheadAttention. Memorize the eight-step checklist and the scaling-factor justification so both become automatic under pressure.
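A sketch of that verification step, assuming `nn.MultiheadAttention` with `batch_first=True` and its default packed projection, where `in_proj_weight` stacks W_q, W_k, W_v into one (3E, E) matrix. The hand-written path reuses those weights and the module's `out_proj`, so the two outputs should agree to floating-point tolerance; the function name is hypothetical.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

def mha_from_torch_weights(x, ref, num_heads, causal=False):
    """Hand-written self-attention reusing the weights of nn.MultiheadAttention `ref`."""
    B, L, E = x.shape
    d_head = E // num_heads
    # packed projection: split the (B, L, 3E) result into Q, K, V
    q, k, v = F.linear(x, ref.in_proj_weight, ref.in_proj_bias).chunk(3, dim=-1)

    def split(t):
        # (B, L, E) -> (B, L, H, d_head) -> (B, H, L, d_head)
        return t.reshape(B, L, num_heads, d_head).transpose(1, 2)

    q, k, v = split(q), split(k), split(v)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_head)
    if causal:
        future = torch.triu(torch.ones(L, L, dtype=torch.bool), diagonal=1)
        scores = scores.masked_fill(future, float("-inf"))
    out = (torch.softmax(scores, dim=-1) @ v).transpose(1, 2).reshape(B, L, E)
    return ref.out_proj(out)

torch.manual_seed(0)
E, H = 32, 4
ref = nn.MultiheadAttention(E, H, batch_first=True).eval()
x = torch.randn(2, 7, E)
mask = torch.triu(torch.ones(7, 7, dtype=torch.bool), diagonal=1)  # True = may not attend
expected, _ = ref(x, x, x, attn_mask=mask, need_weights=False)
ours = mha_from_torch_weights(x, ref, H, causal=True)
print(torch.allclose(ours, expected, atol=1e-5))
```

Calling `.eval()` keeps dropout out of the comparison, although the module's default dropout is already 0.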
Research scientist who knows the math but rarely codes it timed
Profile: Can derive attention gradients on a whiteboard and has read FlashAttention deeply, but does most modeling work at the architecture level and rarely writes the primitive from scratch in twenty minutes.
Pain points: Produces correct but slow code, gets caught on PyTorch broadcasting and transpose semantics, and over-explains the theory while running low on time to finish a working implementation.
Strategy: Practice timed implementation with a hard twenty-minute cap, writing the working code first and reserving theory for follow-up. Lead with the implementation, then offer the FlashAttention and grouped-query depth as bonus when invited. This sequencing converts deep knowledge into interview points instead of time overruns.
FAQ
Q: Should I implement attention in NumPy or PyTorch in an interview?
A: Default to PyTorch unless the interviewer specifies otherwise, since it matches production reality and makes masking and matmul cleaner. If you implement in NumPy, be ready to explain that there is no autograd, so you would only be coding the forward pass unless backward is explicitly requested.
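If you do go with NumPy, the forward pass is short, and it is a natural place to show a numerically stable softmax. This sketch assumes unbatched, unmasked inputs; the function names are illustrative.

```python
import numpy as np

def stable_softmax(x, axis=-1):
    # subtracting the max leaves softmax unchanged but keeps exp() from overflowing
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attention_numpy(q, k, v):
    # q, k: (L, d_k); v: (L, d_v). Forward pass only: NumPy has no autograd.
    scores = q @ k.T / np.sqrt(q.shape[-1])  # (L, L)
    return stable_softmax(scores, axis=-1) @ v  # (L, d_v)

print(stable_softmax(np.array([1000.0, 1000.0])))  # [0.5 0.5]
```

A naive `np.exp(x) / np.exp(x).sum()` on the same input overflows to inf and returns NaN, which is the failure the max subtraction prevents.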
Q: Do I need to implement the backward pass for attention?
A: Usually no. Most interviews ask only for the forward pass and rely on autograd for gradients. If asked to discuss the backward pass, explain that gradients flow through the softmax and the two matmuls via the chain rule, and that FlashAttention recomputes intermediate values in the backward pass to save memory.
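If you are pushed on the backward pass, this sketch writes out the chain rule for single-head, unmasked attention and checks it against autograd. It stores the forward probabilities `p`; FlashAttention instead recomputes them tile by tile from saved softmax statistics. The function name is hypothetical and float64 is used only to make the comparison tight.

```python
import math
import torch

def attention_backward(q, k, v, grad_out):
    """Manual gradients for out = softmax(q k^T / sqrt(d)) v.
    q, k: (L, d); v: (L, d_v); grad_out: upstream gradient of shape (L, d_v)."""
    scale = 1.0 / math.sqrt(q.size(-1))
    p = torch.softmax(q @ k.T * scale, dim=-1)  # forward probabilities (L, L)
    grad_v = p.T @ grad_out  # from out = p @ v
    grad_p = grad_out @ v.T
    # row-wise softmax backward: dS = P * (dP - sum_j dP_j * P_j)
    grad_s = p * (grad_p - (grad_p * p).sum(dim=-1, keepdim=True))
    grad_q = grad_s @ k * scale  # from s = q k^T * scale
    grad_k = grad_s.T @ q * scale
    return grad_q, grad_k, grad_v

torch.manual_seed(0)
q = torch.randn(5, 4, dtype=torch.float64, requires_grad=True)
k = torch.randn(5, 4, dtype=torch.float64, requires_grad=True)
v = torch.randn(5, 3, dtype=torch.float64, requires_grad=True)
out = torch.softmax(q @ k.T / math.sqrt(4), dim=-1) @ v
g = torch.randn_like(out)
out.backward(g)  # autograd reference gradients land in q.grad, k.grad, v.grad
gq, gk, gv = attention_backward(q.detach(), k.detach(), v.detach(), g)
print(torch.allclose(gq, q.grad), torch.allclose(gk, k.grad), torch.allclose(gv, v.grad))
```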
Q: How long should a clean attention implementation take?
A: A fluent candidate writes correct multi-head attention with masking in twelve to eighteen minutes, leaving time for complexity narration and follow-up questions. If it takes you longer than twenty-five minutes in practice, you have not yet built the muscle memory and should keep drilling the primitive.