KV Cache and Memory Management
Deep dive into KV cache optimization - the key to fast and efficient LLM inference
The KV cache is critical for fast LLM generation. Poor management = slow inference and out-of-memory errors.
During autoregressive text generation, transformers produce one token at a time. Naively, we'd recompute attention over the entire sequence for each new token, a computational disaster scaling as $O(n^2)$ per step, where $n$ is the sequence length. The KV cache elegantly solves this by storing previously computed Key and Value matrices, reducing per-token cost from $O(n^2)$ to $O(n)$.
However, this cache becomes the memory bottleneck for long-context inference. For a 70B parameter model generating 4K tokens, the KV cache alone can consume over 10 GB of memory, more than the activation memory! Understanding and optimizing KV cache usage is essential for deploying large language models efficiently.
The Problem: Redundant Computation
Recall the attention formula:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

During generation at step $t$, we compute:
- Query for the new token: $q_t = x_t W_Q$
- Keys for all tokens: $k_i = x_i W_K$ for $i = 1, \dots, t$
- Values for all tokens: $v_i = x_i W_V$ for $i = 1, \dots, t$

The inefficiency: at step $t+1$, we need $k_1, \dots, k_t$ and $v_1, \dots, v_t$ again! Without caching, we recompute these from the original embeddings, so generating a sequence of length $n$ costs $O(n^2)$ key/value projections in total. For 2048 tokens, this is about 2 million unnecessary computations!
Generate "The cat sat on the mat"
Step 1: "The" → compute attention
Step 2: "The cat" → recompute attention for "The", compute for "cat"
Step 3: "The cat sat" → recompute for "The", "cat", compute for "sat"
...
This is O(n²) - extremely slow!
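To see the asymptotics concretely, here is a minimal counting sketch (the helper function is hypothetical, purely illustrative):

def count_kv_projections(n_tokens, use_cache):
    """How many K/V projections are computed to generate n_tokens?"""
    total = 0
    for step in range(1, n_tokens + 1):
        if use_cache:
            total += 1     # only the newest token is projected
        else:
            total += step  # every token so far is re-projected
    return total

print(count_kv_projections(2048, use_cache=False))  # 2098176  (~2M, O(n^2))
print(count_kv_projections(2048, use_cache=True))   # 2048     (O(n))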
KV Cache Solution: Store Key and Value tensors from previous tokens:
import math
import torch
import torch.nn as nn

class AttentionWithCache(nn.Module):
    def __init__(self, d_model=512):
        super().__init__()
        self.d_model = d_model
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        # Cache for keys and values
        self.cached_keys = []
        self.cached_values = []

    def forward(self, x_new, use_cache=True):
        """
        x_new: just the new token [1, d_model]
        """
        # Compute Q, K, V for the new token only
        q_new = self.W_q(x_new)  # [1, d_model]
        k_new = self.W_k(x_new)  # [1, d_model]
        v_new = self.W_v(x_new)  # [1, d_model]
        if use_cache:
            # Add to cache
            self.cached_keys.append(k_new)
            self.cached_values.append(v_new)
            # Use all cached keys/values
            K = torch.cat(self.cached_keys, dim=0)  # [seq_len, d_model]
            V = torch.cat(self.cached_values, dim=0)
        else:
            K, V = k_new, v_new
        # Scaled dot-product attention over the cached sequence
        scores = q_new @ K.T / math.sqrt(self.d_model)  # [1, seq_len]
        weights = torch.softmax(scores, dim=-1)
        output = weights @ V  # [1, d_model]
        return output
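
# Hypothetical usage sketch: feed one token embedding per step. Each call
# projects only the newest token; the cache grows by one K/V entry per step.
attn = AttentionWithCache(d_model=512)
for _ in range(6):  # e.g. generating "The cat sat on the mat"
    x_new = torch.randn(1, 512)  # embedding of the newest token (illustrative)
    out = attn(x_new)            # attends over all cached tokens so far
print(len(attn.cached_keys))     # 6 cached K entries, one per generated token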
# Speed comparison:
# Without cache: 2 tokens/sec
# With cache: 50 tokens/sec
# 25x speedup!

Memory Analysis: The Storage Cost
The KV cache stores two tensors per layer per attention head. For a model with:
- $L$ layers
- $H$ attention heads
- $d_{\text{head}}$ dimensions per head
- $n$ sequence length
- $B$ batch size

The total KV cache size is:

$$\text{KV cache bytes} = 2 \times B \times n \times L \times H \times d_{\text{head}} \times \text{bytes per element}$$

(the factor of 2 accounts for storing both K and V)
Concrete example for a 70B parameter model (LLaMA-2 70B):
- $L = 80$ layers
- $H = 64$ heads
- $d_{\text{head}} = 128$ dimensions per head
- $n = 4096$ tokens (context length)
- $B = 1$ (single user)
- FP16 = 2 bytes
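Plugging these values into the formula:

$$2 \times 1 \times 4096 \times 80 \times 64 \times 128 \times 2 \text{ bytes} = 10{,}737{,}418{,}240 \text{ bytes} \approx 10.7 \text{ GB}$$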
That’s 10.7 GB just to store attention keys and values for a single 4K context! For comparison, the model weights themselves are 140 GB in FP16.
Batch size scaling: With 8 concurrent users (batch size 8), the cache grows to $8 \times 10.7 \text{ GB} \approx 86 \text{ GB}$. This exceeds the memory of a single A100 (80 GB), forcing expensive multi-GPU inference or restricting batch sizes, directly impacting throughput and cost.
The throughput killer: Larger batch sizes amortize the model weight loading cost across multiple requests, dramatically improving tokens/second/dollar. But KV cache memory limits how large batches can grow. This is why PagedAttention and vLLM were such breakthroughs—they enable 5-10x higher batch sizes through clever memory management.
The batch size economics: For a 70B model on an A100 (80 GB):

Without optimization: 140 GB of FP16 weights plus 10.7 GB of KV cache per 4K request, against 80 GB of GPU memory. This is impossible! The weights alone exceed GPU memory.

With quantization (INT8 weights = 70 GB): roughly 10 GB remains for everything else, enough KV cache for a single 4K-context request. Only one request at a time. Terrible utilization!

With PagedAttention + quantization + optimizations: per-request KV memory shrinks enough that batch sizes of 8-16 become feasible.
This 8-16x improvement in batch size translates directly to cost savings. If you’re serving 1M requests/day:
- Batch size 1: Need 16 GPUs → \$48K/month
- Batch size 16: Need 1 GPU → \$3K/month
\$45K/month savings from KV cache optimization alone!
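A back-of-the-envelope sketch of that arithmetic (it assumes throughput scales linearly with batch size; the 16-GPU baseline and ~$3K/GPU/month price are the figures quoted above):

GPU_MONTHLY_COST_K = 3   # ~$3K per GPU per month (assumed figure)
GPUS_AT_BATCH_1 = 16     # baseline from the estimate above

for batch_size in (1, 16):
    gpus = max(1, GPUS_AT_BATCH_1 // batch_size)
    cost = gpus * GPU_MONTHLY_COST_K
    print(f"batch size {batch_size:2d}: {gpus:2d} GPUs -> ${cost}K/month")
# batch size  1: 16 GPUs -> $48K/month
# batch size 16:  1 GPUs -> $3K/month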
Scaling behavior: Memory grows linearly with sequence length. Double the context → double the KV memory.
def calculate_kv_cache_size(
batch_size=1,
seq_len=4096,
num_layers=80,
num_heads=64,
head_dim=128,
precision='fp16'
):
"""
Calculate KV cache memory requirements
"""
bytes_per_element = {
'fp32': 4,
'fp16': 2,
'int8': 1
}[precision]
# K and V for each layer
kv_cache_size = (
2 * # K and V
batch_size *
seq_len *
num_layers *
num_heads *
head_dim *
bytes_per_element
)
return kv_cache_size / 1e9 # Convert to GB
# Example: LLaMA 2 70B
kv_size = calculate_kv_cache_size(
batch_size=1,
seq_len=4096,
num_layers=80,
num_heads=64,
head_dim=128,
precision='fp16'
)
print(f"KV cache size: {kv_size:.2f} GB")
# Output: KV cache size: 10.74 GB (for a single request!)
# At 4K context:
# - Model weights: 140 GB
# - KV cache: 10.7 GB
# - Activations: ~3 GB
# Total: ~154 GB

# At 32K context (8× longer):
# - Model weights: 140 GB
# - KV cache: 86 GB ← 8× larger!
# - Activations: ~3 GB
# Total: ~229 GB

The Problem:
Batch size 32, 4K context each:
├─ KV cache per request: 10.7 GB
├─ Total KV cache: 344 GB
└─ Doesn't fit on any single GPU!
Result: Either small batches (slow) or OOM errors
Optimization 1: PagedAttention (vLLM)
Treat KV cache like virtual memory with pages:
import torch

class PagedKVCache:
    """
    Store KV cache in fixed-size blocks (pages),
    like virtual memory paging in an OS.
    """
    def __init__(self, block_size=16, num_blocks=1024,
                 num_heads=8, head_dim=64):
        self.block_size = block_size  # tokens per block
        self.num_blocks = num_blocks
        # Physical memory: pre-allocated blocks
        # [num_blocks, 2 (K/V), block_size, heads, head_dim]
        self.physical_blocks = torch.zeros(
            num_blocks, 2, block_size, num_heads, head_dim
        )
        # Track which blocks are free
        self.free_blocks = set(range(num_blocks))
        # Logical-to-physical mapping
        self.block_tables = {}  # request_id -> [block_ids]

    def allocate_blocks(self, request_id, num_tokens):
        """Allocate enough blocks for a request"""
        num_blocks_needed = (num_tokens + self.block_size - 1) // self.block_size
        if len(self.free_blocks) < num_blocks_needed:
            raise MemoryError("Out of KV cache blocks")
        allocated = [self.free_blocks.pop() for _ in range(num_blocks_needed)]
        self.block_tables[request_id] = allocated
        return allocated

    def write_kv(self, request_id, token_idx, k, v):
        """Write K, V for one token to the cache"""
        # Find which logical block the token falls in
        block_idx = token_idx // self.block_size
        offset = token_idx % self.block_size
        # Map to the physical block
        physical_block_id = self.block_tables[request_id][block_idx]
        self.physical_blocks[physical_block_id, 0, offset] = k
        self.physical_blocks[physical_block_id, 1, offset] = v

    def free_request(self, request_id):
        """Return blocks to the free pool when a request completes"""
        blocks = self.block_tables.pop(request_id)
        self.free_blocks.update(blocks)
# Benefits:
# 1. No memory fragmentation
# 2. Easy sharing between requests (same prompt)
# 3. Can evict least-recently-used blocks
# 4. Near-zero waste (vs 30% waste in naive implementation)
# vLLM achieves 20-24x higher throughput with PagedAttention

Memory Savings:
Naive KV Cache:
├─ Allocate max_seq_len upfront (4096 tokens)
├─ Actual sequence: 237 tokens
├─ Wasted: 3859 tokens (94%!)
└─ With 32 batch: 30% memory wasted on average
PagedAttention:
├─ Allocate only needed blocks
├─ Actual sequence: 237 tokens → 15 blocks
├─ Wasted: <16 tokens per request (<1%)
└─ 5-7x better memory utilization
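A quick usage sketch of the PagedKVCache class above (request ID, shapes, and values are illustrative):

import torch

cache = PagedKVCache(block_size=16, num_blocks=1024, num_heads=8, head_dim=64)

# A 237-token request needs ceil(237 / 16) = 15 blocks, as in the diagram
blocks = cache.allocate_blocks("req-0", num_tokens=237)
print(len(blocks))  # 15

# K/V for token 40 land in the request's third block (40 // 16 == 2)
k = torch.randn(8, 64)
v = torch.randn(8, 64)
cache.write_kv("req-0", token_idx=40, k=k, v=v)

# When the request finishes, its blocks return to the free pool immediately
cache.free_request("req-0")
print(len(cache.free_blocks))  # 1024 again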
Optimization 2: Multi-Query Attention (MQA)
Share K, V across attention heads:
import math
import torch
import torch.nn as nn

class MultiQueryAttention(nn.Module):
    """
    Use the same K, V for all heads (only Q differs).
    Standard: each head has its own Q, K, V.
    MQA: each head has its own Q; K and V are shared.
    """
    def __init__(self, d_model=512, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        # Separate Q for each head
        self.W_q = nn.Linear(d_model, d_model)
        # Shared K, V (much smaller!)
        self.W_k = nn.Linear(d_model, self.d_k)
        self.W_v = nn.Linear(d_model, self.d_k)

    def forward(self, x):
        batch, seq_len, d_model = x.shape
        # Q: separate for each head
        Q = self.W_q(x).reshape(batch, seq_len, self.num_heads, self.d_k)
        # K, V: shared across heads
        K = self.W_k(x).unsqueeze(2)  # [batch, seq, 1, d_k]
        V = self.W_v(x).unsqueeze(2)
        # Broadcast K, V to all heads
        K = K.expand(-1, -1, self.num_heads, -1)
        V = V.expand(-1, -1, self.num_heads, -1)
        # Rest is standard attention (move the head dim next to batch)
        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_k)
        weights = torch.softmax(scores, dim=-1)
        output = weights @ V
        return output.transpose(1, 2).reshape(batch, seq_len, d_model)
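
# Note on memory: .expand() returns a broadcast view and copies nothing, so
# although K and V are logically "repeated" across all heads at compute time,
# the KV cache only ever stores a single head's worth of K and V.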
# KV cache reduction:
# Standard (64 heads): 2 * 64 * seq_len * d_k
# MQA (64 heads): 2 * 1 * seq_len * d_k
# 64x smaller KV cache!
# Used in: PaLM, Falcon
# Quality: ~1-2% worse than standard attention
# Speed: Much faster generation

Optimization 3: Grouped-Query Attention (GQA)
Compromise between standard and MQA:
import math
import torch
import torch.nn as nn

class GroupedQueryAttention(nn.Module):
    """
    Group heads, share K/V within groups.
    Example with 8 heads, 2 groups:
    - Group 1 (heads 0-3): shared K, V
    - Group 2 (heads 4-7): shared K, V
    """
    def __init__(self, d_model=512, num_heads=8, num_kv_heads=2):
        super().__init__()
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.d_k = d_model // num_heads
        # Q for all heads
        self.W_q = nn.Linear(d_model, d_model)
        # K, V for kv_heads only
        self.W_k = nn.Linear(d_model, num_kv_heads * self.d_k)
        self.W_v = nn.Linear(d_model, num_kv_heads * self.d_k)

    def forward(self, x):
        batch, seq, d_model = x.shape
        # Q: all heads
        Q = self.W_q(x).reshape(batch, seq, self.num_heads, self.d_k)
        # K, V: kv_heads only
        K = self.W_k(x).reshape(batch, seq, self.num_kv_heads, self.d_k)
        V = self.W_v(x).reshape(batch, seq, self.num_kv_heads, self.d_k)
        # Repeat K, V for each group
        heads_per_group = self.num_heads // self.num_kv_heads
        K = K.repeat_interleave(heads_per_group, dim=2)
        V = V.repeat_interleave(heads_per_group, dim=2)
        # Standard attention (move the head dim next to batch)
        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_k)
        weights = torch.softmax(scores, dim=-1)
        output = weights @ V
        return output.transpose(1, 2).reshape(batch, seq, d_model)
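
# Note on memory: unlike .expand(), repeat_interleave materializes copies at
# compute time, but the cache itself still stores only num_kv_heads heads of
# K and V, which is where the num_heads / num_kv_heads memory saving comes from.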
# LLaMA 2 70B uses GQA:
# - 64 query heads
# - 8 KV heads
# - 8x smaller KV cache vs standard
# - Better quality than MQA
# - Nearly as fast
# GQA is becoming the standard (LLaMA 2, Mistral, Mixtral)

Comparison:
| Variant | KV heads | KV cache size | Quality |
|---|---|---|---|
| Standard attention | 64 | Largest (baseline) | Best |
| GQA (LLaMA 2) | 8 | 8x smaller | ~-1% |
| MQA (Falcon) | 1 | 64x smaller | ~-2% |
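To make these ratios concrete, here is a small sketch reusing the calculate_kv_cache_size helper from earlier, passing the number of KV heads as num_heads (head counts from the table above):

for name, kv_heads in [("Standard MHA", 64), ("GQA", 8), ("MQA", 1)]:
    gb = calculate_kv_cache_size(num_heads=kv_heads)  # 4K context, FP16
    print(f"{name:12s} ({kv_heads:2d} KV heads): {gb:6.2f} GB")
# Standard MHA (64 KV heads):  10.74 GB
# GQA          ( 8 KV heads):   1.34 GB
# MQA          ( 1 KV heads):   0.17 GB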
Optimization 4: Quantized KV Cache
Store cache in INT8 instead of FP16:
import torch

class QuantizedKVCache:
    """
    Store the KV cache in INT8 (1 byte) instead of FP16 (2 bytes):
    50% memory savings.
    """
    def __init__(self):
        self.k_cache_int8 = []
        self.v_cache_int8 = []
        self.k_scales = []
        self.v_scales = []

    def quantize(self, tensor):
        """Float → INT8 with a per-tensor scale factor"""
        abs_max = tensor.abs().max()
        scale = abs_max / 127
        tensor_int8 = (tensor / scale).round().clamp(-128, 127).to(torch.int8)
        return tensor_int8, scale

    def dequantize(self, tensor_int8, scale):
        """INT8 → float"""
        return tensor_int8.float() * scale

    def store(self, k, v):
        k_int8, k_scale = self.quantize(k)
        v_int8, v_scale = self.quantize(v)
        self.k_cache_int8.append(k_int8)
        self.v_cache_int8.append(v_int8)
        self.k_scales.append(k_scale)
        self.v_scales.append(v_scale)

    def retrieve(self):
        K_list = [
            self.dequantize(k, scale)
            for k, scale in zip(self.k_cache_int8, self.k_scales)
        ]
        V_list = [
            self.dequantize(v, scale)
            for v, scale in zip(self.v_cache_int8, self.v_scales)
        ]
        return torch.cat(K_list), torch.cat(V_list)
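
# Hypothetical sanity check: the round-trip quantization error on a random
# tensor is small relative to the magnitude of the values themselves.
cache = QuantizedKVCache()
k = torch.randn(1, 512)
k_int8, scale = cache.quantize(k)
k_restored = cache.dequantize(k_int8, scale)
rel_err = (k - k_restored).abs().mean() / k.abs().mean()
print(f"mean relative error: {rel_err:.4f}")  # typically around 1% or less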
# Memory savings:
# FP16 KV cache: 10.7 GB
# INT8 KV cache: 5.4 GB (50% reduction)
# Quality: <0.5% degradation
# Speed: Slightly slower (dequantization overhead)
# Trade-off: Usually worth it for memory savings

Best Practices
Production Deployment:
- Use PagedAttention (vLLM) for serving
- Use GQA in model architecture (if training)
- Quantize KV cache to INT8 if memory-constrained
- Monitor cache hit rates and eviction
- Set appropriate batch sizes based on memory
# vLLM configuration example
from vllm import LLM, SamplingParams
llm = LLM(
model="meta-llama/Llama-2-70b-hf",
# GPU configuration
tensor_parallel_size=2, # Split across 2 GPUs
# KV cache configuration
max_num_seqs=256, # Max concurrent requests
max_num_batched_tokens=8192,
# Memory management
gpu_memory_utilization=0.9, # Use 90% of GPU memory
swap_space=4, # 4 GB CPU swap for overflow
# KV cache quantization
kv_cache_dtype="fp8",  # FP8 KV cache; supported dtypes vary by vLLM version
)
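
# Hypothetical usage with vLLM's standard generate API (prompt is illustrative):
sampling = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=256)
outputs = llm.generate(["Explain the KV cache in one paragraph."], sampling)
print(outputs[0].outputs[0].text)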
# The LLM configuration above:
# - Handles up to 256 concurrent requests
# - Uses PagedAttention automatically
# - Quantizes the KV cache to FP8
# - Swaps to CPU if GPU memory fills up

Resources & Further Reading
Foundational Papers
KV Cache Optimization
- FlashAttention: Fast and Memory-Efficient Exact Attention (Dao et al., 2022)
- IO-aware algorithm that optimizes KV cache access patterns
- 2-4× speedup with exact attention
- FlashAttention-2 (Dao, 2023)
- Further optimizations, up to 2× faster than FlashAttention-1
Memory Optimization
- Paged Attention (vLLM) (Kwon et al., 2023)
- Virtual memory-inspired KV cache management
- Reduces memory waste from fragmentation
- Multi-Query Attention (Shazeer, 2019)
- Share K and V across attention heads
- Reduces KV cache by 8× for 8-head attention
- Grouped-Query Attention (Ainslie et al., 2023)
- Balance between multi-head and multi-query
- Used in LLaMA 2 70B
Quantization
- LLM.int8(): 8-bit Matrix Multiplication (Dettmers et al., 2022)
- Quantize KV cache to INT8
- 2× memory savings with minimal quality loss
Implementation Resources
- vLLM: Production KV cache management
- Flash Attention: Optimized CUDA kernels
- Text Generation Inference: Hugging Face’s optimized inference
- TensorRT-LLM: NVIDIA’s optimized inference engine
Related Technical Guides
- Attention Mechanisms → Understanding attention computation
- Transformer Architecture → Complete architecture overview
- Inference Optimization → General inference speedups
- Long-Context Architecture → Handling longer sequences
Educational Content
Blog Posts
- vLLM: Easy, Fast, and Cheap LLM Serving
- Making Deep Learning Go Brrrr - Optimization principles
- FlashAttention Explained
Video Content
- Efficient LLM Inference - Andrej Karpathy
- vLLM Technical Deep Dive
Advanced Topics
System Design
- Continuous Batching
- Speculative Decoding - Generate multiple tokens per step
Last updated: September 2025