
How Long-Context Models Work: Technical Architecture

Deep dive into the technical innovations that enable models like Claude, Kimi, and GPT-4 to handle 100K+ token contexts

AI Tools Reviews Technical Team
January 20, 2024
Tags: LLM, technical architecture, transformers

How Long-Context Models Work

Processing 100,000+ tokens in a single prompt was once impossible. Standard transformers struggle with sequences beyond 8K tokens due to fundamental computational and memory constraints. So how do models like Claude 3 (200K), Kimi (200K), and GPT-4 Turbo (128K) handle such massive contexts?

This deep-dive explains the technical innovations that make long-context AI possible.

The challenge is brutal: attention complexity scales as O(n²), where n is the sequence length. For 200K tokens, this means computing a 200K × 200K attention matrix of 40 billion elements. At FP16 precision, that's 80GB per layer just for attention scores, before KV cache or activations.

Yet Claude 3 and GPT-4 Turbo handle these massive contexts in real-time. The secret lies in a combination of algorithmic innovations (sparse attention, linear attention approximations), architectural improvements (sliding windows, hierarchical structures), and systems-level optimizations (FlashAttention, efficient KV cache management).

The Core Problem: Quadratic Complexity

Standard transformer attention has O(n²) complexity, where n is the sequence length.

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

The bottleneck is the product QK^T, which produces an n × n matrix. For each element:

$$S_{ij} = \frac{1}{\sqrt{d_k}} \sum_{k=1}^{d_k} Q_{ik} K_{jk}$$

Computing all n² elements requires O(n² · d_k) operations, and storing the attention matrix requires O(n²) memory.

Scaling nightmare:

  • 2K tokens: 4M attention values (16MB at FP32)
  • 8K tokens: 64M values (256MB)
  • 32K tokens: 1B values (4GB)
  • 128K tokens: 16B values (64GB)
  • 200K tokens: 40B values (160GB) ← Impossible without optimization!
attention_complexity.py
import numpy as np

def softmax(x, axis=-1):
  """Numerically stable softmax."""
  x = x - x.max(axis=axis, keepdims=True)
  e = np.exp(x)
  return e / e.sum(axis=axis, keepdims=True)

def standard_attention(Q, K, V):
  """
  Standard scaled dot-product attention
  Complexity: O(n²) in sequence length
  """
  # Q, K, V shape: (batch, seq_len, d_model)
  d_k = Q.shape[-1]

  # This matrix multiplication is O(n²)
  scores = Q @ np.swapaxes(K, -2, -1) / np.sqrt(d_k)
  # scores shape: (batch, seq_len, seq_len)

  attention_weights = softmax(scores, axis=-1)
  output = attention_weights @ V

  return output

# Memory requirements for the attention matrix
seq_lengths = [2048, 8192, 32768, 131072, 200000]

for n in seq_lengths:
  # Attention matrix size in GB (float32)
  memory_gb = (n * n * 4) / (1024**3)
  print(f"Sequence {n:6d}: {memory_gb:8.2f} GB just for attention matrix")

# Output:
# Sequence   2048:     0.02 GB
# Sequence   8192:     0.25 GB
# Sequence  32768:     4.00 GB
# Sequence 131072:    64.00 GB
# Sequence 200000:   149.01 GB  ← Impossible!

At 200K tokens, a single attention operation would require ~150GB just for the attention matrix—before even considering the KV cache or model weights.

Why is this so catastrophic? Consider the GPU memory hierarchy:

GPU Memory Hierarchy (A100, approximate):
On-chip SRAM:  ~20 MB    ← ~19 TB/s (fastest)
L2 Cache:       40 MB    ← several TB/s
HBM (VRAM):     80 GB    ←  ~2 TB/s (main bottleneck for attention)
Host RAM:      512 GB    ←  ~25 GB/s (via PCIe)
Disk:            ∞       ←   ~5 GB/s (unusable for inference)

The ~150GB attention matrix doesn't fit in an A100's 80GB of HBM. Options:

  1. Multi-GPU: Split across GPUs, but inter-GPU bandwidth is 300 GB/s (NVLink) or 25 GB/s (PCIe)—far slower than HBM
  2. Recomputation: Don’t store the attention matrix, recompute in backward pass—but this doubles compute
  3. Approximation: Use sparse or linear-time alternatives—sacrifices some quality

All successful long-context models choose option 3: approximate attention cleverly. The art is preserving model quality while slashing memory by 10-1000x.

Solution 1: Sparse Attention Patterns

Instead of computing all n² attention scores, compute only a sparse subset. If each token attends to only w other tokens:

Complexity: O(n · w) instead of O(n²)

For w = 512 (a constant window), this reduces 200K² = 40B score computations to 200K × 512 ≈ 100M, roughly a 400× reduction!
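
A quick sanity check of that arithmetic (the snippet and its filename are just illustrative):

sparse_savings.py
# Back-of-the-envelope: dense vs. windowed attention scores at 200K tokens
seq_len = 200_000
window = 512

dense_scores = seq_len * seq_len      # 40,000,000,000
sparse_scores = seq_len * window      # 102,400,000
print(f"{dense_scores / sparse_scores:.0f}x fewer scores")  # ~391x fewer scores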

Sliding Window Attention

Each token attends only to a fixed-size local window. For position i with window size w:

$$\text{Attention}_i = \text{softmax}\!\left(\frac{Q_i K_{[i-w:i+w]}^T}{\sqrt{d_k}}\right) V_{[i-w:i+w]}$$

Trade-offs:

  • ✅ Linear complexity: O(n · w)
  • ✅ Good for local dependencies
  • ❌ Cannot capture long-range dependencies directly
  • ❌ Information must propagate through layers
sliding_window.py
import torch
import torch.nn.functional as F

def create_sliding_window_mask(seq_len, window_size):
  """Boolean mask: True where |i - j| <= window_size."""
  positions = torch.arange(seq_len)
  distance = (positions[:, None] - positions[None, :]).abs()
  return distance <= window_size

def sliding_window_attention(Q, K, V, window_size=512):
  """
  Each token attends only to window_size neighbors on each side
  Complexity: O(n * window_size) ≈ O(n)
  (This naive version still materializes the full score matrix and masks it;
  real kernels only compute scores inside the window.)
  """
  batch, seq_len, d_model = Q.shape

  # Attention mask: True inside the window, False elsewhere
  mask = create_sliding_window_mask(seq_len, window_size)
  # mask shape: (seq_len, seq_len)
  # Sparse: only ~window_size * seq_len entries are kept

  scores = (Q @ K.transpose(-2, -1)) / (d_model ** 0.5)
  scores = scores.masked_fill(~mask, float('-inf'))

  attention = F.softmax(scores, dim=-1) @ V
  return attention

# Example: Longformer-style attention (conceptual; global-attention helpers not shown)
def longformer_attention(Q, K, V, global_token_ids, window_size=512):
  """
  Sliding window + global attention on a few special tokens
  """
  local_attn = sliding_window_attention(Q, K, V, window_size)

  # Global tokens attend to (and are attended by) everything
  global_attn = compute_global_attention(Q, K, V, global_token_ids)

  return combine_attentions(local_attn, global_attn)

Used by: Longformer, BigBird, Mistral 7B

Trade-off: Reduces ability to relate distant tokens (though global attention helps)
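
One way to see why depth compensates for the missing long-range links: each layer lets information hop at most one window further, so the effective receptive field grows roughly as layers × window. A rough illustration (the numbers are assumptions, not any specific model's configuration):

receptive_field.py
# Effective receptive field of stacked sliding-window layers
window_size = 4096
num_layers = 32
print(f"Receptive field ≈ {window_size * num_layers:,} tokens")  # ≈ 131,072 tokens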

Axial Attention

Decompose 2D attention into row and column attention:

axial_attention.py
import torch
import torch.nn.functional as F
from einops import rearrange

def attention(x):
  """Plain self-attention over the sequence axis (Q = K = V = x)."""
  d = x.shape[-1]
  scores = x @ x.transpose(-2, -1) / (d ** 0.5)
  return F.softmax(scores, dim=-1) @ x

def axial_attention(x, grid_size, axis=0):
  """
  Reshape the sequence into a (grid_size x grid_size) grid and
  attend along one axis at a time
  Complexity: O(n * sqrt(n)) for the 2D decomposition
  """
  if axis == 0:
    # Attend along rows: each row of grid_size tokens attends internally
    grid = rearrange(x, 'b (h w) d -> (b h) w d', w=grid_size)
    out = attention(grid)
    return rearrange(out, '(b h) w d -> b (h w) d', b=x.shape[0])
  else:
    # Attend along columns
    grid = rearrange(x, 'b (h w) d -> (b w) h d', w=grid_size)
    out = attention(grid)
    return rearrange(out, '(b w) h d -> b (h w) d', b=x.shape[0])

# Reshape sequence into a 2D grid
seq_len = 65536
grid_size = 256  # sqrt(65536)

# Two axial attention passes ≈ 2 * (seq_len * grid_size)
# = 2 * (65536 * 256) ≈ 33M score computations
# vs standard attention: 65536² ≈ 4.3B

Solution 2: Efficient Position Encoding

Standard absolute position embeddings break down at long contexts. Modern models use Rotary Position Embeddings (RoPE), which enable extrapolation to longer sequences.

The problem with absolute positions: Learned position embeddings create a lookup table:

$$\text{pos\_emb}[i] \in \mathbb{R}^{d_{\text{model}}} \quad \text{for } i = 0, 1, \ldots, n_{\max}$$

This has fundamental limitations:

  1. Fixed maximum length: Can't handle positions beyond n_max (see the snippet below)
  2. No generalization: Position 10,000 is unrelated to position 9,999 in the embedding space
  3. Poor extrapolation: A model trained on 4K context fails at 8K
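
A two-line illustration of limitation 1, using a learned torch.nn.Embedding as the position table (the sizes are arbitrary):

position_limit.py
import torch

pos_emb = torch.nn.Embedding(num_embeddings=4096, embedding_dim=128)  # trained for 4K positions
pos_emb(torch.tensor([5000]))  # IndexError: position 5000 was never learned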

RoPE solution: Instead of adding position info to embeddings, rotate the query and key vectors by an angle proportional to their position. The rotation difference naturally encodes relative position.

For a 2D subspace with position m, the rotation matrix is:

$$\mathbf{R}_m = \begin{pmatrix} \cos(m\theta) & -\sin(m\theta) \\ \sin(m\theta) & \cos(m\theta) \end{pmatrix}$$

The magic: when computing the attention score between positions m and n:

$$\text{score}(m, n) = (\mathbf{R}_m q)^T (\mathbf{R}_n k) = q^T \mathbf{R}_m^T \mathbf{R}_n k = q^T \mathbf{R}_{n-m} k$$

The score depends only on the relative distance n - m! This relative encoding enables extrapolation (a quick numeric check of the property follows the list below):

  • Trained on 4K context (positions 0-4095)
  • Model learns “attention decays with distance”
  • Inference on 32K context (positions 0-32767) works because the relative pattern holds
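
A minimal numeric check of the relative-position property, using a single 2D frequency pair with arbitrary illustrative values:

rope_relative_check.py
import numpy as np

def rotate(v, angle):
  """Rotate a 2D vector (one RoPE frequency pair) by the given angle."""
  c, s = np.cos(angle), np.sin(angle)
  return np.array([[c, -s], [s, c]]) @ v

theta = 0.01
q = np.array([0.3, -1.2])
k = np.array([0.8, 0.5])

# Same relative offset (n - m = 7) at two very different absolute positions
score_near = rotate(q, 10 * theta) @ rotate(k, 17 * theta)
score_far = rotate(q, 500 * theta) @ rotate(k, 507 * theta)
print(np.isclose(score_near, score_far))  # True: the score depends only on n - m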

Extrapolation trick: To handle much longer contexts, interpolate the rotation frequencies:

$$\theta'_i = \frac{\theta_i}{s} \quad \text{where } s = \frac{n_{\text{new}}}{n_{\text{train}}}$$

This “compresses” the position space, allowing 4K-trained models to handle 32K+ contexts with minimal degradation.
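
The same idea in a few lines: an out-of-range position is mapped back into the trained range before the rotation is applied (illustrative numbers):

position_interpolation.py
trained_context = 4096
new_context = 32768
s = new_context / trained_context   # 8.0

position = 20000                    # beyond anything seen in training
effective_position = position / s   # 2500.0, well inside the trained range
print(effective_position)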

rope.py
import torch

def apply_rotary_pos_emb(q, k, cos, sin):
  """
  RoPE: rotate query/key vectors by their position-dependent angles
  cos/sin are precomputed per position (see the sketch below)
  Enables length extrapolation beyond the training context
  """
  q_embed = (q * cos) + (rotate_half(q) * sin)
  k_embed = (k * cos) + (rotate_half(k) * sin)
  return q_embed, k_embed

def rotate_half(x):
  """Helper: rotate features for RoPE"""
  x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]
  return torch.cat((-x2, x1), dim=-1)

# RoPE advantages:
# 1. Relative position encoding (better generalization)
# 2. Decays attention with distance naturally
# 3. Can extrapolate to longer sequences than training
# 4. No learned parameters

# Example: Extend context at inference (conceptual)
def extend_rope_context(original_context=4096, new_context=32768):
  scale = new_context / original_context
  # scale_rope_frequencies is a placeholder for the model-specific
  # frequency rescaling (linear interpolation, NTK-aware, YaRN, ...)
  return scale_rope_frequencies(scale)

Used by: LLaMA, GPT-NeoX, Kimi, most modern LLMs

Benefit: Enables context length extrapolation—model can handle longer sequences than seen during training
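
For completeness, a minimal sketch of how the cos/sin tables consumed by apply_rotary_pos_emb above are typically precomputed; the helper name is ours, but the base of 10000 and the duplicated-halves layout follow common open-source RoPE implementations:

rope_tables.py
import torch

def precompute_rope_cos_sin(seq_len, head_dim, base=10000.0):
  """Precompute per-position cos/sin tables for RoPE."""
  # One rotation frequency per pair of dimensions: theta_i = base^(-2i / head_dim)
  inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
  positions = torch.arange(seq_len).float()
  angles = torch.outer(positions, inv_freq)      # (seq_len, head_dim/2)
  angles = torch.cat((angles, angles), dim=-1)   # duplicated to match rotate_half
  return angles.cos(), angles.sin()

cos, sin = precompute_rope_cos_sin(seq_len=4096, head_dim=128)
print(cos.shape)  # torch.Size([4096, 128])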

Solution 3: FlashAttention

Memory-efficient attention implementation that fuses operations:

flash_attention_concept.py
import torch
import torch.nn.functional as F

# Standard attention (memory inefficient): materializes the full n×n matrix
def standard_attention(Q, K, V):
  S = Q @ K.transpose(-2, -1) / (Q.shape[-1] ** 0.5)
  P = F.softmax(S, dim=-1)
  return P @ V

# FlashAttention (conceptual, using the same online-softmax idea as the real kernel)
def flash_attention(Q, K, V, block_size=128):
  """
  Tile the computation so each block fits in fast on-chip SRAM
  Never materialize the full attention matrix; instead keep running
  softmax statistics (row max and normalizer) per query position
  """
  batch, seq_len, d_k = Q.shape
  scale = 1.0 / (d_k ** 0.5)

  output = torch.zeros_like(Q)
  row_max = torch.full((batch, seq_len, 1), float('-inf'))
  row_sum = torch.zeros(batch, seq_len, 1)

  # Process in blocks that would fit in SRAM in a real GPU kernel
  for j in range(0, seq_len, block_size):
    K_block = K[:, j:j+block_size]
    V_block = V[:, j:j+block_size]

    for i in range(0, seq_len, block_size):
      Q_block = Q[:, i:i+block_size]

      # Scores for this tile only: (batch, block, block)
      S_block = Q_block @ K_block.transpose(-2, -1) * scale

      # Online softmax update: new running max, rescale previous partial results
      m_prev = row_max[:, i:i+block_size]
      m_new = torch.maximum(m_prev, S_block.max(dim=-1, keepdim=True).values)
      P_block = torch.exp(S_block - m_new)
      correction = torch.exp(m_prev - m_new)

      row_sum[:, i:i+block_size] = row_sum[:, i:i+block_size] * correction + P_block.sum(dim=-1, keepdim=True)
      output[:, i:i+block_size] = output[:, i:i+block_size] * correction + P_block @ V_block
      row_max[:, i:i+block_size] = m_new

  # Final normalization by the softmax denominator
  return output / row_sum

# Benefits:
# - 2-4x faster than standard attention (no redundant HBM reads/writes)
# - O(n) extra memory instead of O(n²)
# - Enables longer contexts with same hardware

Used by: Most production LLMs (GPT-4, Claude, Kimi likely use FlashAttention 2)

Impact: Makes 100K+ contexts feasible on existing hardware
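
In practice you rarely write these kernels yourself. PyTorch 2.x, for example, exposes fused attention kernels (including FlashAttention-style ones when hardware and dtypes allow) behind a single call:

sdpa_example.py
import torch
import torch.nn.functional as F

# Shapes: (batch, heads, seq_len, head_dim); use CPU/float32 if no GPU is available
q = torch.randn(1, 8, 4096, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 4096, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 4096, 64, device="cuda", dtype=torch.float16)

# Dispatches to a fused memory-efficient / flash kernel when one is available
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([1, 8, 4096, 64])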

Solution 4: KV Cache Optimization

During generation, cache key and value matrices to avoid recomputation:

kv_cache.py
import torch

class KVCache:
  """
  Cache key/value tensors during autoregressive generation so each new
  token attends against cached K/V instead of re-encoding the whole
  prefix: O(n) work per generated token instead of O(n²)
  """
  def __init__(self, max_seq_len, num_layers, d_model, batch_size=1):
    self.max_seq_len = max_seq_len
    # Allocate cache: (num_layers, 2, batch, max_seq_len, d_model)
    self.cache = torch.zeros(num_layers, 2, batch_size, max_seq_len, d_model)
    # Number of filled positions, tracked per layer
    self.seq_lens = [0] * num_layers

  def update(self, layer_idx, k, v):
    """Append new keys/values for one layer and return the full cached K, V"""
    # k, v shape: (batch, new_tokens, d_model)
    _, new_tokens, _ = k.shape

    # Store in cache
    start = self.seq_lens[layer_idx]
    end = start + new_tokens
    self.cache[layer_idx, 0, :, start:end] = k
    self.cache[layer_idx, 1, :, start:end] = v

    self.seq_lens[layer_idx] = end

    # Return full cached K, V for this layer
    return (
      self.cache[layer_idx, 0, :, :end],
      self.cache[layer_idx, 1, :, :end]
    )

# Memory calculation for 200K context
def calculate_kv_cache_memory(
  context_length=200000,
  num_layers=80,
  d_model=8192,
  precision="float16"
):
  bytes_per_element = 2 if precision == "float16" else 4

  memory_bytes = (
    2 *  # K and V
    num_layers *
    context_length *
    d_model *
    bytes_per_element
  )

  memory_gb = memory_bytes / (1024**3)
  print(f"KV Cache: {memory_gb:.2f} GB")
  return memory_gb

# At 200K context with 80 layers (full multi-head attention, FP16):
calculate_kv_cache_memory()
# Output: KV Cache: 488.28 GB  ← still far too large on its own

# Optimization: Grouped-Query Attention (GQA)
# Share KV across query heads to reduce memory
def gqa_memory_savings(num_query_heads=64, num_kv_heads=8):
  standard_kv = num_query_heads
  gqa_kv = num_kv_heads
  savings = (standard_kv - gqa_kv) / standard_kv
  print(f"GQA Memory Savings: {savings:.1%}")
  return savings

gqa_memory_savings()
# Output: GQA Memory Savings: 87.5%
# With GQA (8 of 64 heads): 488 GB / 8 ≈ 61 GB of KV cache at FP16
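
To make the GQA idea concrete, here is a minimal sketch (not any particular model's implementation) of sharing K/V heads across groups of query heads; only K/V, and therefore the cache, shrink, and they are expanded back at attention time:

gqa_sketch.py
import torch
import torch.nn.functional as F

def grouped_query_attention(q, k, v):
  """
  q: (batch, num_query_heads, seq_len, head_dim)
  k, v: (batch, num_kv_heads, seq_len, head_dim), with far fewer heads
  Each group of num_query_heads // num_kv_heads query heads shares one K/V head.
  """
  num_query_heads, num_kv_heads = q.shape[1], k.shape[1]
  group_size = num_query_heads // num_kv_heads

  # Expand the shared K/V heads so every query head has a partner
  k = k.repeat_interleave(group_size, dim=1)
  v = v.repeat_interleave(group_size, dim=1)

  scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
  return F.softmax(scores, dim=-1) @ v

# 64 query heads share 8 K/V heads → the cache stores 8 heads, not 64
q = torch.randn(1, 64, 128, 128)
k = torch.randn(1, 8, 128, 128)
v = torch.randn(1, 8, 128, 128)
print(grouped_query_attention(q, k, v).shape)  # torch.Size([1, 64, 128, 128])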

Solution 5: Training Strategies

Progressive Length Training

Train on increasingly longer sequences:

Stage 1: 4K context   - 80% of training
Stage 2: 16K context  - 15% of training  
Stage 3: 64K context  - 4% of training
Stage 4: 200K context - 1% of training

Why: Most real-world usage is shorter sequences; focus compute there

Length Extrapolation

Use techniques to extend context beyond training:

length_extrapolation.py
# YaRN: Yet another RoPE extensioN (conceptual sketch)
def yarn_scaling(target_context, trained_context):
  """
  Rescale RoPE frequencies so a model trained on trained_context
  tokens can attend over target_context tokens
  """
  scale = target_context / trained_context

  # Non-uniform scaling: low-frequency (long-wavelength) dimensions are
  # interpolated by the full scale factor, high-frequency dimensions are
  # left largely untouched, with a smooth ramp in between
  low_freq_factor = scale
  high_freq_factor = 1.0

  # interpolate_frequencies is a placeholder for the per-dimension
  # frequency ramp defined in the YaRN paper
  return interpolate_frequencies(
    low_freq_factor,
    high_freq_factor,
    temperature=0.5
  )

# Example: Train on 32K, inference on 200K
extended_rope = yarn_scaling(
  target_context=200000,
  trained_context=32768
)
# Enables ~6x context extension with minimal quality loss

Putting It All Together

Modern long-context models combine these techniques:

Technical Specifications

  • Sparse attention: sliding window + global attention
  • Position encoding: RoPE with interpolation
  • Efficient implementation: FlashAttention 2
  • KV cache: GQA (8-16 KV heads)
  • Training: progressive length scaling

Example: Claude 3 Architecture (estimated)

claude3_architecture.py
# Estimated Claude 3 Opus architecture
class Claude3Opus:
  num_layers = 80
  d_model = 8192
  num_heads = 64
  num_kv_heads = 8  # GQA
  max_context = 200000
  
  # Sparse attention pattern
  window_size = 4096  # Local window
  global_attention_every_n = 8  # Global attention layers
  
  # Position encoding
  position_encoding = "RoPE"
  rope_theta = 10000
  rope_scaling = "YaRN"  # For context extension
  
  # Optimization
  attention_impl = "FlashAttention2"
  kv_cache_quantization = "FP8"  # Reduce memory
  
  def estimate_memory(self):
      # Model weights: ~150GB (70B params * 2 bytes)
      # KV cache: ~50GB (with GQA and quantization)
      # Activation: ~20GB (recomputed with gradient checkpointing)
      return 220  # GB total

Performance Implications

Time-to-First-Token (TTFT):

  • Prefill cost grows with context length (linear in the feed-forward layers, quadratic in attention), so long prompts dominate latency
  • 200K context ≈ 5-10s prefill on modern GPUs

Tokens-per-Second (TPS):

  • Generation speed degrades only gradually as the context grows
  • The bottleneck is memory bandwidth (streaming weights and the KV cache), not compute (see the sketch below)
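
A rough roofline-style illustration of why bandwidth dominates decoding; every number below is an assumption, not a measurement:

decode_speed_estimate.py
# Every generated token must stream the weights plus the KV cache from HBM
weights_gb = 140                      # hypothetical 70B params at FP16
kv_cache_gb = 61                      # 200K context with GQA (from the calculation above)
aggregate_bandwidth_gb_s = 8 * 2000   # assumed 8 GPUs at ~2 TB/s each

ms_per_token = (weights_gb + kv_cache_gb) / aggregate_bandwidth_gb_s * 1000
print(f"{ms_per_token:.1f} ms/token ≈ {1000 / ms_per_token:.0f} tokens/s")  # 12.6 ms/token ≈ 80 tokens/s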

Accuracy:

  • “Lost in the middle” problem persists
  • Performance degrades 5-10% for middle-context retrieval

The Future: 1M+ Contexts

Google's Gemini 1.5 reaches 1M-token contexts, reportedly through:

  • More aggressive sparse attention
  • Novel compression techniques
  • Hierarchical processing (summarize chunks)

This enables:

  • Processing entire codebases (100K+ lines)
  • Analyzing multiple books simultaneously
  • Hours of video/audio transcripts


Last updated: November 2025