Attention Mechanisms Explained
Visual guide to how attention works in transformers - from basic self-attention to modern sparse patterns
Attention is the core innovation that powers modern AI models like GPT-4, Claude, and Gemini. Understanding how it works is key to understanding why these models are so capable—and where their limitations come from.
The attention mechanism solves a fundamental problem in sequence modeling: how can a model decide which parts of the input are relevant for predicting the next token? The answer is beautifully simple yet profoundly powerful: let the model learn to “attend” to relevant positions through a weighted sum, where the weights are computed dynamically based on the content.
This dynamic, content-based routing of information is what enables transformers to:
- Handle long-range dependencies (“The cat” and “sat” can directly interact)
- Process sequences in parallel (no sequential bottleneck)
- Learn rich contextual representations (each token gathers relevant info from all others)
The Core Idea
Traditional RNNs process sequences one token at a time, maintaining a hidden state:
Input: "The cat sat on the mat"
↓ ↓ ↓ ↓ ↓ ↓
RNN: [h1]→[h2]→[h3]→[h4]→[h5]→[h6]
Problems:
- Information from “The” is diluted by the time we reach “mat”
- Can’t process in parallel (slow training)
- Struggles with long-range dependencies
Attention solves this by letting each token directly look at all other tokens:
The cat sat on the mat
The ● ○ ○ ○ ○ ○
cat ● ● ○ ○ ○ ○
sat ● ● ● ○ ○ ○
on ● ● ● ● ○ ○
the ● ● ● ● ● ○
mat ● ● ● ● ● ●
● = attends to (looks at)
Each token can directly access information from any previous token!
Self-Attention: The Mathematical Heart
Self-attention computes a weighted combination of input representations, where the weights depend on the input itself. The full mechanism can be expressed in one equation:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Let's unpack this step by step.
Step 1: Create Query, Key, Value Projections
For each token embedding $x_i$, we create three different projections:

$$q_i = x_i W^Q, \qquad k_i = x_i W^K, \qquad v_i = x_i W^V$$

where $W^Q, W^K, W^V$ are learned projection matrices.
Intuition:
- Query ($q$): “What am I looking for?” — represents the token’s search query
- Key ($k$): “What do I offer?” — represents the token’s content descriptor
- Value ($v$): “What information do I provide?” — the actual information to propagate
Think of it like a database lookup: queries search for relevant keys, and matching keys return their associated values.
import numpy as np
def create_qkv(token_embedding, W_q, W_k, W_v):
"""
Transform token embedding into Q, K, V vectors
"""
query = token_embedding @ W_q # "What am I looking for?"
key = token_embedding @ W_k # "What do I contain?"
value = token_embedding @ W_v # "What do I output?"
return query, key, value
# Example: Process "cat" token
# Embedding: [0.2, 0.5, 0.1, 0.8]
token_emb = np.array([0.2, 0.5, 0.1, 0.8])
# Weight matrices (learned during training)
W_q = np.random.randn(4, 64) # Projects to 64-dim query
W_k = np.random.randn(4, 64) # Projects to 64-dim key
W_v = np.random.randn(4, 64) # Projects to 64-dim value
q, k, v = create_qkv(token_emb, W_q, W_k, W_v)
print(f"Query shape: {q.shape}") # (64,)
print(f"Key shape: {k.shape}") # (64,)
print(f"Value shape: {v.shape}") # (64,) Visual representation:
Token: "cat"
↓
[Embedding: 0.2, 0.5, 0.1, 0.8]
↓
├──→ W_q → Query: [0.3, 0.7, ..., 0.2] "Looking for: subjects"
├──→ W_k → Key: [0.1, 0.4, ..., 0.9] "I contain: animal info"
└──→ W_v → Value: [0.5, 0.2, ..., 0.6] "My meaning: feline"
Step 2: Compute Attention Scores with Scaled Dot-Product
We measure the compatibility between each query and all keys using dot products:

$$\text{score}(q_i, k_j) = \frac{q_i \cdot k_j}{\sqrt{d_k}}$$

The scaling factor $\frac{1}{\sqrt{d_k}}$ is crucial. Without it, for large $d_k$, the dot products grow large in magnitude, pushing the softmax into regions with extremely small gradients (saturation).
Why does this happen? Consider the variance of a dot product. If $q$ and $k$ have components drawn from a standard normal distribution $\mathcal{N}(0, 1)$, their dot product has variance:

$$\text{Var}(q \cdot k) = \text{Var}\!\left(\sum_{m=1}^{d_k} q_m k_m\right) = d_k$$
As $d_k$ increases (modern models typically use $d_k = 64$ or $128$ per head), raw dot products have standard deviation $\sqrt{d_k}$ and routinely reach magnitudes of 8 or more. When these extreme values hit softmax, we get probabilities like $0.999$ vs. $0.001$—essentially one-hot vectors. The gradient of softmax in these regions is near-zero, causing the dreaded vanishing gradient problem.
Scaling by $\frac{1}{\sqrt{d_k}}$ normalizes the variance back to 1, keeping dot products in a reasonable range (typically within a few units of zero) where softmax gradients remain healthy. This simple trick was key to making transformers trainable at scale.
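A quick numerical sketch makes this concrete: sampling random queries and keys shows the raw dot-product variance growing with $d_k$, while the scaled version stays near 1.
import numpy as np
rng = np.random.default_rng(0)
for d_k in (4, 64, 1024):
    # Sample many random query/key pairs with N(0, 1) components
    q = rng.standard_normal((10_000, d_k))
    k = rng.standard_normal((10_000, d_k))
    dots = np.sum(q * k, axis=1)      # raw dot products: variance ≈ d_k
    scaled = dots / np.sqrt(d_k)      # scaled dot products: variance ≈ 1
    print(f"d_k={d_k:5d}  var(raw)={dots.var():8.1f}  var(scaled)={scaled.var():.2f}")
# Large raw scores would push softmax toward one-hot outputs with near-zero gradients.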
Matrix form: For a sequence of length $n$:

$$S = \frac{QK^T}{\sqrt{d_k}} \in \mathbb{R}^{n \times n}$$

where $S_{ij}$ represents how much token $i$ should attend to token $j$.
def compute_attention_scores(Q, K):
"""
Q: (seq_len, d_k) - queries for all tokens
K: (seq_len, d_k) - keys for all tokens
Returns: (seq_len, seq_len) attention scores
"""
d_k = K.shape[-1]
# Dot product: how similar is each query to each key?
scores = Q @ K.T # Shape: (seq_len, seq_len)
# Scale by sqrt(d_k) to prevent vanishing gradients
scores = scores / np.sqrt(d_k)
return scores
# Example with 3 tokens: ["The", "cat", "sat"]
seq_len, d_k = 3, 64
Q = np.random.randn(seq_len, d_k)
K = np.random.randn(seq_len, d_k)
scores = compute_attention_scores(Q, K)
print("Attention scores:")
print(scores)
# Output (example):
# [[ 2.1 0.3 0.1] ← "The" attends mostly to itself
# [ 0.8 3.5 0.2] ← "cat" attends mostly to itself
# [ 0.4 1.2 2.9]] ← "sat" attends to "cat" and itself
Visual:
Sequence: "The cat sat"
Query from "sat": [0.5, 0.2, 0.8, ...]
↓
Compare with all keys:
Key "The": [0.1, 0.9, 0.2, ...] → Score: 0.4 (low similarity)
Key "cat": [0.6, 0.3, 0.7, ...] → Score: 1.2 (high similarity!)
Key "sat": [0.5, 0.2, 0.9, ...] → Score: 2.9 (very high)
Result: "sat" pays most attention to itself and "cat"
Step 3: Apply Softmax to Get Attention Weights
We convert scores to probabilities using softmax:

$$\alpha_{ij} = \frac{\exp(S_{ij})}{\sum_{m=1}^{n} \exp(S_{im})}$$
This ensures:
- All weights are positive: $\alpha_{ij} > 0$
- Weights sum to 1: $\sum_{j} \alpha_{ij} = 1$
- Higher scores → higher weights (exponential emphasis)
Causal masking (for GPT-style models): We mask future positions by setting $S_{ij} = -\infty$ for $j > i$ before softmax. This prevents the model from “cheating” by looking ahead.
Convert scores to probabilities:
def apply_causal_mask_and_softmax(scores):
"""
Apply causal mask (can't attend to future tokens)
Then convert to probabilities with softmax
"""
seq_len = scores.shape[0]
# Create causal mask: upper triangle = -inf
mask = np.triu(np.ones((seq_len, seq_len)) * -1e9, k=1)
# [[ 0, -inf, -inf],
# [ 0, 0, -inf],
# [ 0, 0, 0]]
masked_scores = scores + mask
# Softmax: convert to probabilities
attention_weights = softmax(masked_scores, axis=-1)
return attention_weights
def softmax(x, axis=-1):
exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
return exp_x / np.sum(exp_x, axis=axis, keepdims=True)
# Example output:
# [[1.00, 0.00, 0.00], ← "The" only attends to itself
# [0.35, 0.65, 0.00], ← "cat" attends 35% to "The", 65% to itself
# [0.15, 0.30, 0.55]] ← "sat" attends to all 3 tokens
Visual probability distribution:
Token: "sat"
Attention weights:
The: ███░░░░░░░ (15%)
cat: ██████░░░░ (30%)
sat: ███████████ (55%)
↑
Highest attention to itself
Step 4: Weighted Sum of Values
The final step computes the attention output as a weighted combination of value vectors:

$$\text{output}_i = \sum_{j} \alpha_{ij} v_j$$

where $\alpha_{ij}$ are the attention weights from Step 3. This creates a new representation for token $i$ that incorporates information from all other tokens, weighted by their relevance.
In matrix form:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

This is the complete attention mechanism in one beautiful equation!
def apply_attention(attention_weights, V):
"""
attention_weights: (seq_len, seq_len) - probabilities
V: (seq_len, d_v) - value vectors for all tokens
Returns: (seq_len, d_v) - attention output
"""
# Weighted sum: each token is mixture of all value vectors
output = attention_weights @ V
return output
# For token "sat" with weights [0.15, 0.30, 0.55]:
# output = 0.15 * V["The"] + 0.30 * V["cat"] + 0.55 * V["sat"]
# This gives "sat" a representation that includes information
# from "The" (15%), "cat" (30%), and itself (55%)
Complete Attention Formula
def scaled_dot_product_attention(Q, K, V, mask=None):
"""
Complete attention mechanism in one function
"""
d_k = Q.shape[-1]
# 1. Compute scores
scores = (Q @ K.T) / np.sqrt(d_k)
# 2. Apply mask (for causal attention)
if mask is not None:
scores = scores + mask
# 3. Softmax
attention_weights = softmax(scores, axis=-1)
# 4. Weighted sum
output = attention_weights @ V
return output, attention_weights
# Usage for a full sequence (illustrative; assumes a helper that
# embeds and projects every token at once):
seq = ["The", "cat", "sat", "on", "mat"]
Q, K, V = create_qkv_for_sequence(seq)
output, weights = scaled_dot_product_attention(Q, K, V)
# Result: each token's output contains contextual information
# from all previous tokens it attended to
Multi-Head Attention: Parallel Attention Streams
Single-head attention has a fundamental limitation: it can only learn one type of relationship. What if we want to simultaneously capture:
- Syntactic dependencies (“what word modifies what”)
- Semantic relationships (“what concepts are related”)
- Positional patterns (“what typically comes after what”)
- Coreference resolution (“what pronouns refer to what nouns”)
The solution: multi-head attention runs $h$ attention mechanisms in parallel, each learning different aspects of the relationships.

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\,W^O$$

where each head is:

$$\text{head}_i = \text{Attention}(QW_i^Q,\; KW_i^K,\; VW_i^V)$$

Each head has its own projection matrices $W_i^Q, W_i^K, W_i^V$, allowing it to learn different attention patterns. The outputs are concatenated and projected through $W^O$.
Parameter efficiency: Instead of using the full $d_{\text{model}}$ dimensions for each head, we typically split:

$$d_k = d_v = d_{\text{model}} / h$$

So 8 heads with a 512-dim model use 64 dims per head. This keeps the parameter count similar to single-head attention!
Why does this work? Each head operates in a lower-dimensional subspace ($d_k = 64$ instead of $d_{\text{model}} = 512$) but learns specialized patterns. Head 1 might learn subject-verb agreement, Head 2 might learn entity relationships, Head 3 might track positional patterns. Together, they capture richer representations than any single head could.
Computational cost: Multi-head attention with $h$ heads costs the same as single-head! We’re trading depth ($d_k$ per head) for breadth ($h$ heads). The matrix multiplications total:

$$h \times O(n^2 \cdot d_k) = O(n^2 \cdot d_{\text{model}})$$

Identical to single-head attention, but with richer learned patterns.
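As a quick sanity check on the parameter-count claim, we can count the projection weights directly (a small illustrative calculation using the same 512-dim, 8-head configuration as the class below):
d_model, num_heads = 512, 8
d_k = d_model // num_heads  # 64
# Single-head: full-width Q, K, V projections, each of shape (d_model, d_model)
single_head_params = 3 * d_model * d_model
# Multi-head: per head, Q/K/V projections of shape (d_model, d_k), times num_heads heads
multi_head_params = num_heads * 3 * d_model * d_k
print(single_head_params, multi_head_params)  # 786432 786432 (identical)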
class MultiHeadAttention:
    def __init__(self, d_model=512, num_heads=8):
        self.num_heads = num_heads
        self.d_model = d_model
        self.d_k = d_model // num_heads  # 512 / 8 = 64
        # Separate Q, K, V projections for each head
        self.W_q = [np.random.randn(d_model, self.d_k) for _ in range(num_heads)]
        self.W_k = [np.random.randn(d_model, self.d_k) for _ in range(num_heads)]
        self.W_v = [np.random.randn(d_model, self.d_k) for _ in range(num_heads)]
        # Final output projection
        self.W_o = np.random.randn(d_model, d_model)

    def forward(self, x):
        # x: (seq_len, d_model), consistent with the 2-D helpers defined above
        seq_len, d_model = x.shape
        heads_output = []
        # Process each head independently
        for i in range(self.num_heads):
            Q = x @ self.W_q[i]  # (seq_len, d_k)
            K = x @ self.W_k[i]
            V = x @ self.W_v[i]
            head_out, _ = scaled_dot_product_attention(Q, K, V)
            heads_output.append(head_out)
        # Concatenate all heads: (seq_len, num_heads * d_k) = (seq_len, d_model)
        multi_head = np.concatenate(heads_output, axis=-1)
        # Final projection
        output = multi_head @ self.W_o
        return output

Sparse Attention Patterns
For long contexts, full attention is too expensive. Modern models use sparse patterns:
1. Sliding Window
Sequence: [1, 2, 3, 4, 5, 6, 7, 8]
Window size: 3
Token 5 can attend to:
... [3, 4, 5, 6, 7] ...
└──window──┘
Not [1, 2, 8] (outside window)
Complexity: O(n * window_size) instead of O(n²)
2. Global + Local
Global tokens: [START, SECTION1, SECTION2, END]
(attend to everything, everyone attends to them)
Regular tokens: sliding window + global tokens
Example attention pattern:
1 2 3 G 4 5 6 G 7 8
1 [● ● ○ ● ○ ○ ○ ○ ○ ○] local + global
2 [● ● ● ● ○ ○ ○ ○ ○ ○]
3 [○ ● ● ● ● ○ ○ ○ ○ ○]
G [● ● ● ● ● ● ● ● ● ●] global attends all
4 [○ ○ ○ ● ● ● ● ● ○ ○]
...
● = attends to
G = global token
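One way to sketch this mask in NumPy (the helper below is illustrative, not taken from any particular library):
import numpy as np
def global_local_mask(seq_len, window_size, global_positions):
    """
    Boolean attention mask: True = "may attend".
    Combines a sliding window with a set of global token positions.
    """
    mask = np.zeros((seq_len, seq_len), dtype=bool)
    for i in range(seq_len):
        start = max(0, i - window_size)
        end = min(seq_len, i + window_size + 1)
        mask[i, start:end] = True  # local window around token i
    for g in global_positions:
        mask[:, g] = True          # everyone attends to global tokens
        mask[g, :] = True          # global tokens attend to everyone
    return mask

# Example: 10 tokens, 1 neighbor on each side, tokens 3 and 7 are global
print(global_local_mask(10, 1, [3, 7]).astype(int))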
3. Block-Sparse
Divide sequence into blocks, attend within blocks + cross-block:
Block 1: [1,2,3,4]
Block 2: [5,6,7,8]
Block 3: [9,10,11,12]
Token 6 attends to:
- All of Block 2 (local)
- Summary tokens from Block 1, 3 (cross-block)
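A simplified block-sparse mask could look like this (illustrative only; real implementations such as BigBird combine several patterns):
import numpy as np
def block_sparse_mask(seq_len, block_size):
    """
    Boolean mask: each token attends to its own block, plus the first
    token of every block (standing in for a block summary token).
    """
    mask = np.zeros((seq_len, seq_len), dtype=bool)
    for i in range(seq_len):
        block_start = (i // block_size) * block_size
        mask[i, block_start:block_start + block_size] = True  # within-block attention
        mask[i, ::block_size] = True                          # cross-block summary tokens
    return mask

# Example: 12 tokens in blocks of 4; token 6 sees tokens 4-7 plus 0, 4, 8
print(block_sparse_mask(12, 4).astype(int))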
def sliding_window_attention(Q, K, V, window_size=512):
"""
Each token only attends to window_size neighbors
"""
seq_len = Q.shape[0]
output = np.zeros_like(Q)
for i in range(seq_len):
# Define window
start = max(0, i - window_size // 2)
end = min(seq_len, i + window_size // 2)
# Attention only within window
q_i = Q[i:i+1] # Current token query
k_window = K[start:end] # Keys in window
v_window = V[start:end] # Values in window
scores = (q_i @ k_window.T) / np.sqrt(Q.shape[-1])
weights = softmax(scores)
output[i] = weights @ v_window
return output
# Memory savings:
# Full attention: 200K * 200K = 40 billion elements
# Window attention: 200K * 512 = 102 million elements
# Reduction: 390x less memory!
Why Attention Works So Well
- Parallel processing: All tokens computed simultaneously (vs RNN sequential)
- Long-range dependencies: Direct connection between distant tokens
- Learned patterns: Model learns what to attend to
- Context-aware: Each token’s representation includes relevant context
- Flexible: Multi-head captures different relationship types
Limitations
- Quadratic complexity: O(n²) for full attention
- Lost in the middle: Long contexts can dilute information
- No inherent position: Needs position encodings added
- Expensive inference: Large KV cache for long contexts (see the rough estimate below)
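To put rough numbers on the quadratic-cost and KV-cache points above, here is a back-of-the-envelope sketch; the layer count, head count, head size, and fp16 precision are assumed purely for illustration:
def attention_memory_estimate(seq_len, num_layers=32, num_heads=32,
                              d_head=128, bytes_per_value=2):
    """
    Very rough estimates with assumed (hypothetical) model dimensions:
    - one head's attention score matrix grows as seq_len^2
    - the KV cache grows linearly in seq_len, multiplied across layers and heads
    """
    scores_per_head = seq_len ** 2 * bytes_per_value
    kv_cache = 2 * num_layers * num_heads * d_head * seq_len * bytes_per_value
    return scores_per_head / 1e9, kv_cache / 1e9  # gigabytes

for n in (4_096, 32_768, 200_000):
    scores_gb, kv_gb = attention_memory_estimate(n)
    print(f"seq_len={n:>7}: score matrix ≈ {scores_gb:.2f} GB/head, KV cache ≈ {kv_gb:.2f} GB")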
Resources & Further Reading
Foundational Papers
Original Attention Papers
- Attention is All You Need (Vaswani et al., 2017)
- Introduces scaled dot-product attention and multi-head attention
- Section 3.2 covers the attention mechanism in detail
- Neural Machine Translation by Jointly Learning to Align and Translate (Bahdanau et al., 2014)
- First attention mechanism for sequence-to-sequence models
Attention Variants
- FlashAttention: Fast and Memory-Efficient Exact Attention (Dao et al., 2022)
- IO-aware algorithm making attention 2-4× faster
- Longformer: The Long-Document Transformer (Beltagy et al., 2020)
- Sparse attention patterns for long sequences
- BigBird: Transformers for Longer Sequences (Zaheer et al., 2020)
- Random, window, and global attention combined
Implementation Resources
- PyTorch Attention Tutorial: Official implementation guide
- Flash Attention: Optimized CUDA kernels
- xFormers: Facebook’s efficient attention implementations
- Hugging Face: Attention implementations
Related Technical Guides
- Transformer Architecture → Complete architecture overview
- Position Encodings → How transformers understand order
- Long-Context Architecture → Scaling to longer sequences
- KV Cache Optimization → Memory-efficient attention
- Inference Optimization → Making attention faster
Educational Content
Visualizations
- The Illustrated Transformer - Jay Alammar’s visual guide
- Attention? Attention! - Lilian Weng’s comprehensive overview
- Transformer Explainer - Interactive tool
Advanced Topics
Efficient Attention
- Linear Attention - Reducing complexity to O(n)
- Grouped Query Attention - Reducing KV cache memory
- Multi-Query Attention - Sharing keys and values
Sparse Attention Patterns
- Sparse Transformers - Factorized attention patterns
- Reformer - Locality-sensitive hashing for attention
Last updated: December 2025