How Long-Context Models Work: Technical Architecture
Deep dive into the technical innovations that enable models like Claude, Kimi, and GPT-4 to handle 100K+ token contexts
Processing 100,000+ tokens in a single prompt was once impossible. Standard transformers struggle with sequences beyond 8K tokens due to fundamental computational and memory constraints. So how do models like Claude 3 (200K), Kimi (200K), and GPT-4 Turbo (128K) handle such massive contexts?
This deep-dive explains the technical innovations that make long-context AI possible.
The challenge is brutal: attention complexity scales as O(n²), where n is the sequence length. For 200K tokens, this means computing a 200K × 200K attention matrix with 40 billion elements! At FP16 precision, that's 80GB per layer just for attention scores, before KV cache or activations.
Yet Claude 3 and GPT-4 Turbo handle these massive contexts in real-time. The secret lies in a combination of algorithmic innovations (sparse attention, linear attention approximations), architectural improvements (sliding windows, hierarchical structures), and systems-level optimizations (FlashAttention, efficient KV cache management).
The Core Problem: Quadratic Complexity
Standard transformer attention has O(n²) complexity, where n is the sequence length.
The bottleneck is computing QKᵀ, which produces an n × n score matrix. For each element:

scores[i][j] = (q_i · k_j) / √d_k

Computing all n² elements requires O(n² · d) operations, and O(n²) memory to store the attention matrix.
Scaling nightmare:
- 2K tokens: 4M attention values (16MB at FP32)
- 8K tokens: 64M values (256MB)
- 32K tokens: 1B values (4GB)
- 128K tokens: 16B values (64GB)
- 200K tokens: 40B values (160GB) ← Impossible without optimization!
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax along the given axis."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def standard_attention(Q, K, V):
    """
    Standard scaled dot-product attention.
    Complexity: O(n²) in sequence length.
    """
    # Q, K, V shape: (batch, seq_len, d_model)
    d_k = Q.shape[-1]
    # This matrix multiplication is O(n²)
    scores = Q @ np.swapaxes(K, -2, -1) / np.sqrt(d_k)
    # scores shape: (batch, seq_len, seq_len)
    attention_weights = softmax(scores, axis=-1)
    output = attention_weights @ V
    return output
# Memory requirements for attention matrix
seq_lengths = [2048, 8192, 32768, 131072, 200000]
for n in seq_lengths:
# Attention matrix size in GB (float32)
memory_gb = (n * n * 4) / (1024**3)
print(f"Sequence {n:6d}: {memory_gb:8.2f} GB just for attention matrix")
# Output:
# Sequence 2048: 0.02 GB
# Sequence 8192: 0.25 GB
# Sequence 32768: 4.00 GB
# Sequence 131072: 64.00 GB
# Sequence 200000:   149.01 GB ← Impossible!

At 200K tokens, a single attention operation would require ~150GB just for the attention matrix, before even considering the KV cache or model weights.
Why is this so catastrophic? Consider the GPU memory hierarchy:
GPU Memory Hierarchy (A100):
  On-chip SRAM:  ~20 MB   ← ~19,000 GB/s (fastest)
  L2 Cache:       40 MB   ← several TB/s
  HBM (VRAM):     80 GB   ← ~1,935 GB/s (main bottleneck)
  Host RAM:      512 GB   ← ~25 GB/s (via PCIe)
  Disk:           ∞       ← ~5 GB/s (unusable for inference)
The 150GB attention matrix doesn’t fit in the 80GB A100 memory. Options:
- Multi-GPU: Split across GPUs, but inter-GPU bandwidth is 300 GB/s (NVLink) or 25 GB/s (PCIe)—far slower than HBM
- Recomputation: Don’t store the attention matrix, recompute in backward pass—but this doubles compute
- Approximation: Use sparse or linear-time alternatives—sacrifices some quality
All successful long-context models choose option 3: approximate attention cleverly. The art is preserving model quality while slashing memory by 10-1000x.
Solution 1: Sparse Attention Patterns
Instead of computing all n² attention scores, compute only a sparse subset. If each token attends to only k other tokens, the cost drops from O(n²) to O(n · k).
For k = 512 (a constant window), this reduces 200K² = 40B computations to 200K × 512 ≈ 100M, a ~400× reduction!
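A quick back-of-the-envelope check of that reduction (plain Python, nothing model-specific assumed):

full_tokens = 200_000
window = 512

full_attention_scores = full_tokens ** 2         # every token attends to every token
sparse_attention_scores = full_tokens * window   # every token attends to a fixed window

print(f"Full attention:   {full_attention_scores:,} score computations")
print(f"Sparse attention: {sparse_attention_scores:,} score computations")
print(f"Reduction:        {full_attention_scores / sparse_attention_scores:.0f}x")
# Full attention:   40,000,000,000 score computations
# Sparse attention: 102,400,000 score computations
# Reduction:        391x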
Sliding Window Attention
Each token attends only to a fixed-size local window. For position i with window size w, token i attends only to positions within w of i.
Trade-offs:
- ✅ Linear complexity: O(n · w)
- ✅ Good for local dependencies
- ❌ Cannot capture long-range dependencies directly
- ❌ Information must propagate through layers
import torch
import torch.nn.functional as F

def create_sliding_window_mask(seq_len, window_size):
    """Boolean mask: True where |i - j| <= window_size // 2."""
    idx = torch.arange(seq_len)
    return (idx[None, :] - idx[:, None]).abs() <= window_size // 2

def sliding_window_attention(Q, K, V, window_size=512):
    """
    Each token attends only to window_size neighbors.
    Complexity: O(n * window_size) ≈ O(n)
    (This dense-mask version still materializes the full matrix;
     real implementations use block-sparse kernels.)
    """
    batch, seq_len, d_model = Q.shape
    # Create attention mask: True inside the window, False elsewhere
    mask = create_sliding_window_mask(seq_len, window_size)
    # mask shape: (seq_len, seq_len); only ~window_size * seq_len True entries
    scores = (Q @ K.transpose(-2, -1)) / (d_model ** 0.5)
    scores = scores.masked_fill(~mask, float('-inf'))
    attention = F.softmax(scores, dim=-1) @ V
    return attention
# Example: Longformer-style attention (schematic)
def longformer_attention(Q, K, V, global_token_ids, window_size=512):
    """
    Sliding window attention plus global attention on a few special tokens.
    compute_global_attention and combine_attentions are placeholders for the
    block-sparse kernels used in real implementations.
    """
    local_attn = sliding_window_attention(Q, K, V, window_size)
    # Designated global tokens attend to, and are attended by, every position
    global_attn = compute_global_attention(Q, K, V, global_token_ids)
    return combine_attentions(local_attn, global_attn)

Used by: Longformer, BigBird, Mistral 7B
Trade-off: Reduces ability to relate distant tokens (though global attention helps)
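Information can still flow across the full sequence, but only indirectly: each layer widens the effective receptive field by roughly one window, so depth matters. A small sketch of that intuition (the window and layer counts below are illustrative, not any particular model's configuration):

def effective_receptive_field(window_size, num_layers):
    """With sliding-window attention, layer L can indirectly mix information
    from up to roughly window_size * L positions away."""
    return window_size * num_layers

# Illustrative numbers only
print(effective_receptive_field(window_size=4096, num_layers=32))   # 131072
# A 4K window stacked over 32 layers can, in principle, propagate information
# across ~131K tokens, though it becomes increasingly diluted with distance.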
Axial Attention
Decompose 2D attention into row and column attention:
from einops import rearrange

def axial_attention(x, grid_h, grid_w, axis):
    """
    Attend along one axis of a 2D reshaping of the sequence.
    `attention` is any standard attention over the shorter axis (placeholder).
    Complexity: O(n * sqrt(n)) for a square grid.
    """
    if axis == 0:
        # Attend along rows (sequences of length grid_w)
        x = rearrange(x, 'b (h w) d -> (b h) w d', h=grid_h, w=grid_w)
    else:
        # Attend along columns (sequences of length grid_h)
        x = rearrange(x, 'b (h w) d -> (b w) h d', h=grid_h, w=grid_w)
    return attention(x)

# Reshape sequence into a 2D grid
seq_len = 65536
grid_size = 256  # sqrt(65536)
# Two axial attention passes ≈ 2 * (seq_len * grid_size)
#   = 2 * (65536 * 256) ≈ 33M ops
# vs standard attention: 65536² ≈ 4.3B ops

Solution 2: Efficient Position Encoding
Standard absolute position embeddings break down at long contexts. Modern models use Rotary Position Embeddings (RoPE), which enable extrapolation to longer sequences.
The problem with absolute positions: Learned position embeddings create a lookup table E ∈ ℝ^(L_max × d), where position p simply indexes row E[p].
This has fundamental limitations:
- Fixed maximum length: Can't handle positions beyond L_max
- No generalization: Position 10,000 is unrelated to position 9,999 in the embedding space
- Poor extrapolation: Model trained on 4K context fails at 8K
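A minimal sketch of the failure mode, assuming a plain learned embedding table (the sizes are illustrative):

import torch
import torch.nn as nn

max_len, d_model = 4096, 64
pos_embedding = nn.Embedding(max_len, d_model)   # one learned vector per position

in_range = torch.arange(0, max_len)
vectors = pos_embedding(in_range)                # fine: every position was trained

out_of_range = torch.arange(max_len, 2 * max_len)
# pos_embedding(out_of_range)                    # IndexError: there is simply no
                                                 # row for positions beyond max_len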
RoPE solution: Instead of adding position info to embeddings, rotate the query and key vectors by an angle proportional to their position. The rotation difference naturally encodes relative position.
For a 2D subspace with position m and rotation frequency θ, the rotation matrix is:

R(mθ) = [ cos(mθ)  −sin(mθ) ]
        [ sin(mθ)   cos(mθ) ]

The magic: when computing attention between positions m and n:

⟨R(mθ) q, R(nθ) k⟩ = ⟨q, R((n − m)θ) k⟩

The score depends only on the relative distance n − m! This relative encoding enables extrapolation:
- Trained on 4K context (positions 0-4095)
- Model learns “attention decays with distance”
- Inference on 32K context (positions 0-32767) works because the relative pattern holds
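A small numerical check of the relative-position property (a toy 2D example with an arbitrary frequency, not any model's actual RoPE configuration):

import torch

def rotate_2d(v, angle):
    """Rotate a 2D vector by `angle` radians."""
    c, s = torch.cos(angle), torch.sin(angle)
    return torch.stack([c * v[0] - s * v[1], s * v[0] + c * v[1]])

theta = torch.tensor(0.01)          # toy frequency
q = torch.tensor([1.0, 2.0])
k = torch.tensor([0.5, -1.0])

# Score for a query at position 100 and a key at position 90 ...
score_a = rotate_2d(q, 100 * theta) @ rotate_2d(k, 90 * theta)
# ... equals the score for positions 1100 and 1090: only the gap (10) matters
score_b = rotate_2d(q, 1100 * theta) @ rotate_2d(k, 1090 * theta)

print(score_a, score_b)   # identical up to floating-point error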
Extrapolation trick: To handle much longer contexts, interpolate the rotation frequencies: θ′_i = θ_i · L_train / L_new (equivalently, positions are rescaled as m′ = m · L_train / L_new).
This “compresses” the position space, allowing 4K-trained models to handle 32K+ contexts with minimal degradation.
import torch
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
"""
RoPE: Rotate query/key vectors based on position
Enables length extrapolation beyond training context
"""
# Split features into pairs
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
def rotate_half(x):
"""Helper: rotate features for RoPE"""
x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]
return torch.cat((-x2, x1), dim=-1)
# RoPE advantages:
# 1. Relative position encoding (better generalization)
# 2. Decays attention with distance naturally
# 3. Can extrapolate to longer sequences than training
# 4. No learned parameters
# Example: Extend context at inference time
def extend_rope_context(original_context=4096, new_context=32768):
    scale = new_context / original_context
    # scale_rope_frequencies is a placeholder for the frequency/position
    # interpolation applied inside the model's RoPE implementation
    return scale_rope_frequencies(scale)

Used by: LLaMA, GPT-NeoX, Kimi, most modern LLMs
Benefit: Enables context length extrapolation—model can handle longer sequences than seen during training
Solution 3: FlashAttention
Memory-efficient attention implementation that fuses operations:
# Standard attention (memory inefficient)
def standard_attention(Q, K, V):
S = Q @ K.T # Materialize full attention matrix
P = softmax(S)
O = P @ V
return O
# FlashAttention (conceptual)
def flash_attention(Q, K, V, block_size=128):
    """
    Tile the computation so each block fits in fast SRAM and never
    materialize the full n×n attention matrix. Correct results require
    the online-softmax trick: a running max and normalizer per query row.
    Q, K, V shape: (batch, seq_len, d)
    """
    batch, seq_len, d = Q.shape
    scale = d ** -0.5
    output = torch.zeros_like(Q)

    # Process query blocks that fit in SRAM
    for i in range(0, seq_len, block_size):
        Q_blk = Q[:, i:i + block_size] * scale
        rows = Q_blk.shape[1]
        m = Q.new_full((batch, rows, 1), float('-inf'))  # running row max
        l = Q.new_zeros((batch, rows, 1))                # running softmax denominator
        acc = torch.zeros_like(Q_blk)                    # unnormalized output

        # Stream over key/value blocks
        for j in range(0, seq_len, block_size):
            K_blk = K[:, j:j + block_size]
            V_blk = V[:, j:j + block_size]
            S = Q_blk @ K_blk.transpose(-2, -1)          # scores for this tile
            m_new = torch.maximum(m, S.max(dim=-1, keepdim=True).values)
            P = torch.exp(S - m_new)                     # tile softmax numerator
            correction = torch.exp(m - m_new)            # rescale previous statistics
            l = l * correction + P.sum(dim=-1, keepdim=True)
            acc = acc * correction + P @ V_blk
            m = m_new

        output[:, i:i + block_size] = acc / l            # final normalization
    return output
# Benefits:
# - 2-4x faster than standard attention
# - O(N) memory instead of O(N²)
# - Enables longer contexts with same hardware

Used by: Most production LLMs (GPT-4, Claude, Kimi likely use FlashAttention 2)
Impact: Makes 100K+ contexts feasible on existing hardware
Solution 4: KV Cache Optimization
During generation, cache key and value matrices to avoid recomputation:
class KVCache:
    """
    Cache key and value tensors during autoregressive generation.
    Avoids recomputing K/V for the whole prefix at every step, reducing
    per-token cost from O(n²) to O(n).
    """
    def __init__(self, max_seq_len, num_layers, d_model, batch_size=1):
        self.max_seq_len = max_seq_len
        # Cache layout: (num_layers, 2, batch, max_seq_len, d_model)
        self.cache = torch.zeros(num_layers, 2, batch_size, max_seq_len, d_model)
        # Track the filled length per layer
        self.seq_lens = [0] * num_layers

    def update(self, layer_idx, k, v):
        """Append new keys/values of shape (batch, new_tokens, d_model)."""
        _, new_tokens, _ = k.shape
        start = self.seq_lens[layer_idx]
        end = start + new_tokens
        self.cache[layer_idx, 0, :, start:end] = k
        self.cache[layer_idx, 1, :, start:end] = v
        self.seq_lens[layer_idx] = end
        # Return the full cached K, V for this layer
        return (
            self.cache[layer_idx, 0, :, :end],
            self.cache[layer_idx, 1, :, :end],
        )
# Memory calculation for 200K context
def calculate_kv_cache_memory(
context_length=200000,
num_layers=80,
d_model=8192,
num_heads=64,
precision="float16"
):
bytes_per_element = 2 if precision == "float16" else 4
memory_bytes = (
2 * # K and V
num_layers *
context_length *
d_model *
bytes_per_element
)
memory_gb = memory_bytes / (1024**3)
print(f"KV Cache: {memory_gb:.2f} GB")
return memory_gb
# At 200K context with 80 layers:
calculate_kv_cache_memory()
# Output: KV Cache: 488.28 GB (full multi-head attention, FP16)
# Optimization: Grouped-Query Attention (GQA)
# Share each K/V head across a group of query heads to shrink the cache
def gqa_memory_savings(num_query_heads=64, num_kv_heads=8):
    savings = (num_query_heads - num_kv_heads) / num_query_heads
    print(f"GQA Memory Savings: {savings:.1%}")
    return savings

gqa_memory_savings()
# Output: GQA Memory Savings: 87.5%
# With 8 KV heads instead of 64, the ~488 GB cache above shrinks to ~61 GB;
# FP8 quantization roughly halves it again.

Solution 5: Training Strategies
Progressive Length Training
Train on increasingly longer sequences:
Stage 1: 4K context - 80% of training
Stage 2: 16K context - 15% of training
Stage 3: 64K context - 4% of training
Stage 4: 200K context - 1% of training
Why: Most real-world usage is shorter sequences; focus compute there
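A minimal sketch of such a schedule (the stage proportions mirror the list above; the helper and step counts are illustrative, not any lab's actual recipe):

# (context_length, fraction of total training steps), as in the stages above
LENGTH_SCHEDULE = [
    (4_096,   0.80),
    (16_384,  0.15),
    (65_536,  0.04),
    (200_000, 0.01),
]

def context_length_for_step(step, total_steps, schedule=LENGTH_SCHEDULE):
    """Return the context length to train on at a given step."""
    progress = step / total_steps
    cumulative = 0.0
    for length, fraction in schedule:
        cumulative += fraction
        if progress < cumulative:
            return length
    return schedule[-1][0]

# e.g. with 100K total steps, step 85,000 falls in the 16K stage
print(context_length_for_step(85_000, 100_000))   # 16384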
Length Extrapolation
Use techniques to extend context beyond training:
# YaRN: Yet another RoPE extensioN
def yarn_scaling(target_context, trained_context):
    """
    Interpolate RoPE frequencies for longer contexts (schematic).
    `interpolate_frequencies` stands in for YaRN's NTK-by-parts scaling.
    """
    scale = target_context / trained_context
    # Non-uniform scaling: high-frequency (short-wavelength) dimensions are
    # left nearly unchanged; low-frequency (long-wavelength) dimensions are
    # interpolated by ~1/scale so positions stay within the trained range.
    high_freq_factor = 1.0
    low_freq_factor = 1.0 / scale
    return interpolate_frequencies(
        low_freq_factor,
        high_freq_factor,
        temperature=0.5,
    )

# Example: train on 32K, run inference at 200K
extended_rope = yarn_scaling(
    target_context=200000,
    trained_context=32768,
)
# Enables ~6x context extension with minimal quality loss

Putting It All Together
Modern long-context models combine these techniques:
Technical Specifications
- Sparse attention: sliding window + global
- Position encoding: RoPE with interpolation
- Efficient implementation: FlashAttention 2
- KV cache: GQA (8-16 KV heads)
- Training: progressive length scaling
Example: Claude 3 Architecture (estimated)
# Estimated Claude 3 Opus architecture
class Claude3Opus:
num_layers = 80
d_model = 8192
num_heads = 64
num_kv_heads = 8 # GQA
max_context = 200000
# Sparse attention pattern
window_size = 4096 # Local window
global_attention_every_n = 8 # Global attention layers
# Position encoding
position_encoding = "RoPE"
rope_theta = 10000
rope_scaling = "YaRN" # For context extension
# Optimization
attention_impl = "FlashAttention2"
kv_cache_quantization = "FP8" # Reduce memory
def estimate_memory(self):
# Model weights: ~150GB (70B params * 2 bytes)
# KV cache: ~50GB (with GQA and quantization)
# Activation: ~20GB (recomputed with gradient checkpointing)
        return 220  # GB total

Performance Implications
Time-to-First-Token (TTFT):
- Prefill cost grows with context length (roughly quadratically for full attention, closer to linearly with sparse patterns)
- 200K context ≈ 5-10s prefill on modern GPUs
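A rough way to see where that number comes from; all figures below are ballpark assumptions, not measurements:

# Prefill compute: roughly 2 * N_params * n_tokens FLOPs for a dense model
n_params = 70e9          # assumed model size
n_tokens = 200_000
prefill_flops = 2 * n_params * n_tokens          # ≈ 2.8e16 FLOPs

# Assume an 8-GPU node at ~300 TFLOPS effective per GPU (optimistic utilization)
effective_flops = 8 * 300e12
print(f"Estimated prefill time: {prefill_flops / effective_flops:.1f} s")
# Estimated prefill time: 11.7 s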
Tokens-per-Second (TPS):
- Generation speed ~constant regardless of context
- Bottleneck is memory bandwidth, not compute
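A similar back-of-the-envelope for generation, which is dominated by reading the weights (and KV cache) from HBM on every step; again, the numbers are assumptions for illustration:

# Per generated token, roughly all weights plus the KV cache are read from HBM
weight_bytes = 140e9          # ~70B params at FP16
kv_cache_bytes = 60e9         # ~200K context with GQA (see Solution 4)
hbm_bandwidth = 2e12          # ~2 TB/s per GPU

bytes_per_token = weight_bytes + kv_cache_bytes
tokens_per_second = hbm_bandwidth / bytes_per_token
print(f"~{tokens_per_second:.0f} tokens/s per GPU (before batching/parallelism)")
# ~10 tokens/s per GPU (before batching/parallelism)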
Accuracy:
- “Lost in the middle” problem persists
- Performance degrades 5-10% for middle-context retrieval
The Future: 1M+ Contexts
Google’s Gemini 1.5 achieves 1M tokens through:
- More aggressive sparse attention
- Novel compression techniques
- Hierarchical processing (summarize chunks)
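One way to picture hierarchical processing is a map-reduce over chunks. This is a generic sketch, not Gemini's actual pipeline, and summarize_chunk is a hypothetical model call:

def hierarchical_summary(text_chunks, fanout=8):
    """Summarize chunks, then recursively summarize the summaries (map-reduce)."""
    summaries = [summarize_chunk(c) for c in text_chunks]   # hypothetical model call
    while len(summaries) > 1:
        groups = [summaries[i:i + fanout] for i in range(0, len(summaries), fanout)]
        summaries = [summarize_chunk("\n".join(g)) for g in groups]
    return summaries[0]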
This enables:
- Processing entire codebases (100K+ lines)
- Analyzing multiple books simultaneously
- Hours of video/audio transcripts
Resources & Further Reading
Foundational Papers
Efficient Attention
- FlashAttention: Fast and Memory-Efficient Exact Attention (Dao et al., 2022)
- IO-aware algorithm enabling longer contexts
- Used in GPT-4, Claude, and other production models
- FlashAttention-2 (Dao, 2023)
- 2× faster with better parallelization
Sparse Attention Patterns
- Longformer: The Long-Document Transformer (Beltagy et al., 2020)
- Sliding window + global attention
- 4096 token context
- BigBird: Transformers for Longer Sequences (Zaheer et al., 2020)
- Random + window + global attention
- Theoretical analysis showing sparse attention can approximate full attention
Linear Attention
- Linformer: Self-Attention with Linear Complexity (Wang et al., 2020)
- Project keys and values to lower dimensions
- Performer: Fast Attention via Orthogonal Random Features (Choromanski et al., 2021)
- Kernel-based approximation of softmax attention
Long-Range Models
- Extending Context Window of Large Language Models via Positional Interpolation (Chen et al., 2023)
- Scale RoPE for longer contexts
- YaRN: Efficient Context Window Extension (Peng et al., 2023)
- NTK-aware scaling for better interpolation
Implementation Resources
- Flash Attention: Optimized CUDA kernels
- Longformer: Hugging Face implementation
- xFormers: Efficient attention variants
- vLLM: Production serving with long contexts
Related Technical Guides
- Attention Mechanisms → Understanding attention fundamentals
- Transformer Architecture → Base architecture
- KV Cache Optimization → Memory management
- Position Encodings → RoPE scaling for long contexts
- Inference Optimization → Speed optimizations
Advanced Topics
State Space Models
- Mamba: Linear-Time Sequence Modeling
- Alternative to attention with linear complexity
- Structured State Space Models (S4)
- Convolutional approach to sequence modeling
Last updated: November 2025