
KV Cache and Memory Management

Deep dive into KV cache optimization - the key to fast and efficient LLM inference

AI Tools Reviews Technical Team
January 26, 2024
Tags: LLM, technical, optimization, kv-cache, memory


The KV cache is critical for fast LLM generation; poor management leads to slow inference and out-of-memory errors.

During autoregressive text generation, transformers produce one token at a time. Naively, we'd recompute attention over the entire sequence for each new token, a computational disaster scaling as O(n^2), where n is the sequence length. The KV cache elegantly solves this by storing previously computed Key and Value matrices, reducing the per-token cost of recomputing them from O(n) to O(1).

However, this cache becomes the memory bottleneck for long-context inference. For a 70B parameter model generating 4K tokens, the KV cache alone can consume 10+ GB of memory, more than the activation memory! Understanding and optimizing KV cache usage is essential for deploying large language models efficiently.

The Problem: Redundant Computation

Recall the attention formula:

\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

During generation at step t, we compute:

  • Query for the new token: Q_t
  • Keys for all tokens: K_1, K_2, ..., K_t
  • Values for all tokens: V_1, V_2, ..., V_t

The inefficiency: at step t+1, we need K_1, ..., K_t again! Without caching, we recompute these from the original embeddings:

\text{Without cache}: \underbrace{O(t)}_{\text{compute } K_{1:t}} + \underbrace{O(1)}_{\text{new } K_{t+1}} = O(t) \text{ per token}

\text{Total for } n \text{ tokens}: \sum_{t=1}^{n} O(t) = O(n^2)

For 2048 tokens, that is roughly 2 million redundant key/value computations!
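A quick back-of-envelope check of that count (an illustrative snippet; the file name is made up):

redundancy_check.py
n = 2048
redundant = sum(range(1, n + 1))  # sum_{t=1}^{n} t = n(n+1)/2
print(f"{redundant:,} per-token K/V recomputations")  # 2,098,176 ≈ 2 million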

Generate "The cat sat on the mat"

Step 1: "The" → compute attention
Step 2: "The cat" → recompute attention for "The", compute for "cat"
Step 3: "The cat sat" → recompute for "The", "cat", compute for "sat"
...

This is O(n²) - extremely slow!

KV Cache Solution: Store Key and Value tensors from previous tokens:

kv_cache_basic.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionWithCache(nn.Module):
    def __init__(self, d_model=512):
        super().__init__()
        self.d_model = d_model
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

        # Cache for keys and values
        self.cached_keys = []
        self.cached_values = []

    def forward(self, x_new, use_cache=True):
        """
        x_new: just the new token [1, d_model]
        """
        # Compute Q, K, V for the new token only
        q_new = self.W_q(x_new)  # [1, d_model]
        k_new = self.W_k(x_new)  # [1, d_model]
        v_new = self.W_v(x_new)  # [1, d_model]

        if use_cache:
            # Add to cache
            self.cached_keys.append(k_new)
            self.cached_values.append(v_new)

            # Use all cached keys/values
            K = torch.cat(self.cached_keys, dim=0)  # [seq_len, d_model]
            V = torch.cat(self.cached_values, dim=0)
        else:
            K, V = k_new, v_new

        # Scaled dot-product attention over the cached sequence
        scores = q_new @ K.T / math.sqrt(self.d_model)  # [1, seq_len]
        weights = F.softmax(scores, dim=-1)
        output = weights @ V  # [1, d_model]

        return output

# Speed comparison:
# Without cache: 2 tokens/sec
# With cache: 50 tokens/sec
# 25x speedup!

Memory Analysis: The Storage Cost

The KV cache stores two tensors per layer per attention head. For a model with:

  • L layers
  • H attention heads
  • d_h dimension per head
  • n sequence length
  • b batch size

The total KV cache size is:

\text{KV cache size} = 2 \times L \times H \times d_h \times n \times b \times \text{bytes per element}

Concrete example for a 70B parameter model (LLaMA-2 70B):

  • L = 80 layers
  • H = 64 heads
  • d_h = 128 dimensions per head
  • n = 4096 tokens (context length)
  • b = 1 (single user)
  • FP16 = 2 bytes per element

\text{KV cache} = 2 \times 80 \times 64 \times 128 \times 4096 \times 1 \times 2 \text{ bytes} = 10.74 \text{ GB}

That’s 10.7 GB just to store attention keys and values for a single 4K context! For comparison, the model weights themselves are 140 GB in FP16.

Batch size scaling: With 8 concurrent users (batch size 8):

\text{KV cache} = 10.74 \text{ GB} \times 8 = 85.9 \text{ GB}

This exceeds the memory of a single A100 (80GB), forcing expensive multi-GPU inference or restricting batch sizes—directly impacting throughput and cost.

The throughput killer: Larger batch sizes amortize the model weight loading cost across multiple requests, dramatically improving tokens/second/dollar. But KV cache memory limits how large batches can grow. This is why PagedAttention and vLLM were such breakthroughs—they enable 5-10x higher batch sizes through clever memory management.

The batch size economics: For a 70B model on an A100 (80GB):

Without optimization:

\text{Max batch size} = \frac{80\text{ GB} - 140\text{ GB weights}}{10.7\text{ GB per request}} \approx 0

This is impossible! The weights alone exceed GPU memory. With quantization (INT8 weights = 70GB):

\text{Max batch size} = \frac{80\text{ GB} - 70\text{ GB}}{10.7\text{ GB}} \approx 1

Only one request at a time. Terrible utilization!

With PagedAttention + quantization + other optimizations:

\text{Max batch size} \approx 8\text{-}16

This 8-16x improvement in batch size translates directly to cost savings. If you’re serving 1M requests/day:

  • Batch size 1: need 16 GPUs → $48K/month
  • Batch size 16: need 1 GPU → $3K/month

That's $45K/month in savings from KV cache optimization alone!
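The arithmetic above fits in a few lines of Python. Here is a minimal sizing sketch (the file name is made up; the ~$3K/month-per-GPU price is the figure assumed in the example above):

batch_size_economics.py
def max_batch_size(gpu_mem_gb, weight_mem_gb, kv_per_request_gb):
    """Rough count of concurrent requests that fit once the weights are loaded."""
    return (gpu_mem_gb - weight_mem_gb) / kv_per_request_gb

print(f"{max_batch_size(80, 140, 10.7):.2f}")  # negative: FP16 weights alone overflow an 80 GB A100
print(f"{max_batch_size(80, 70, 10.7):.2f}")   # ~0.9: barely one request with INT8 weights

# Cost comparison from the example above (~$3K/month per GPU assumed)
for batch_size, gpus in [(1, 16), (16, 1)]:
    print(f"batch {batch_size:>2}: {gpus:>2} GPU(s) -> ${gpus * 3000:,}/month")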


Putting the pieces together, the full memory picture for a single 4K-context request is roughly:

  • Model weights: ~140 GB (FP16)
  • Activations: ~3 GB
  • KV cache: ~11 GB (significant!)

Scaling behavior: Memory grows linearly with sequence length. Double the context → double the KV memory.

kv_cache_memory.py
def calculate_kv_cache_size(
  batch_size=1,
  seq_len=4096,
  num_layers=80,
  num_heads=64,
  head_dim=128,
  precision='fp16'
):
  """
  Calculate KV cache memory requirements
  """
  bytes_per_element = {
      'fp32': 4,
      'fp16': 2,
      'int8': 1
  }[precision]
  
  # K and V for each layer
  kv_cache_size = (
      2 *  # K and V
      batch_size *
      seq_len *
      num_layers *
      num_heads *
      head_dim *
      bytes_per_element
  )
  
  return kv_cache_size / 1e9  # Convert to GB

# Example: LLaMA 2 70B
kv_size = calculate_kv_cache_size(
  batch_size=1,
  seq_len=4096,
  num_layers=80,
  num_heads=64,
  head_dim=128,
  precision='fp16'
)
print(f"KV cache size: {kv_size:.2f} GB")
# Output: 5.37 GB for single request!

# At 4K context:
# - Model weights: 140 GB
# - KV cache: 5.4 GB
# - Activations: ~3 GB
# Total: ~148 GB

# At 32K context (8× longer):
# - Model weights: 140 GB
# - KV cache: 43 GB  ← 8× larger!
# - Activations: ~3 GB
# Total: ~186 GB

The Problem:

Batch size 32, 4K context each:
├─ KV cache per request: 10.7 GB
├─ Total KV cache: ~344 GB
└─ Doesn't fit on any single GPU!

Result: Either small batches (slow) or OOM errors

Optimization 1: PagedAttention (vLLM)

Treat KV cache like virtual memory with pages:

paged_attention.py
import torch

class PagedKVCache:
  """
  Store KV cache in fixed-size blocks (pages)
  Like virtual memory paging in OS
  """
  def __init__(self, block_size=16, num_blocks=1024, num_heads=64, head_dim=128):
      self.block_size = block_size  # tokens per block
      self.num_blocks = num_blocks

      # Physical memory: pre-allocated blocks
      self.physical_blocks = torch.zeros(
          num_blocks, 2, block_size, num_heads, head_dim
      )  # [num_blocks, 2 (K/V), block_size, heads, head_dim]
      
      # Track which blocks are free
      self.free_blocks = set(range(num_blocks))
      
      # Logical to physical mapping
      self.block_tables = {}  # request_id -> [block_ids]
  
  def allocate_blocks(self, request_id, num_tokens):
      """Allocate blocks for a request"""
      num_blocks_needed = (num_tokens + self.block_size - 1) // self.block_size
      
      # Get free blocks
      if len(self.free_blocks) < num_blocks_needed:
          raise MemoryError("Out of KV cache blocks")
      
      allocated = []
      for _ in range(num_blocks_needed):
          block_id = self.free_blocks.pop()
          allocated.append(block_id)
      
      self.block_tables[request_id] = allocated
      return allocated
  
  def write_kv(self, request_id, token_idx, k, v):
      """Write K, V to cache"""
      # Find which block
      block_idx = token_idx // self.block_size
      offset = token_idx % self.block_size
      
      # Get physical block
      physical_block_id = self.block_tables[request_id][block_idx]
      
      # Write
      self.physical_blocks[physical_block_id, 0, offset] = k
      self.physical_blocks[physical_block_id, 1, offset] = v
  
  def free_request(self, request_id):
      """Free blocks when request completes"""
      blocks = self.block_tables.pop(request_id)
      self.free_blocks.update(blocks)

# Benefits:
# 1. No memory fragmentation
# 2. Easy sharing between requests (same prompt)
# 3. Can evict least-recently-used blocks
# 4. Near-zero waste (vs 30% waste in naive implementation)

# vLLM reports up to 24x higher throughput than naive HF Transformers serving, largely thanks to PagedAttention
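To make the paging concrete, here is a small usage sketch of the class above (illustrative request ID, sizes, and file name; in real vLLM the scheduler drives these calls):

paged_kv_usage.py
import torch

cache = PagedKVCache(block_size=16, num_blocks=1024, num_heads=64, head_dim=128)

# A request arrives with a 237-token prompt: reserve just enough pages for it
blocks = cache.allocate_blocks(request_id="req-42", num_tokens=237)
print(len(blocks))  # 15 blocks of 16 tokens -> only 3 padded slots wasted

# During prefill/decoding, each token's K and V land in their page
k = torch.randn(64, 128)  # [num_heads, head_dim]
v = torch.randn(64, 128)
cache.write_kv("req-42", token_idx=0, k=k, v=v)

# When the request finishes, its pages go straight back to the free pool
cache.free_request("req-42")
print(len(cache.free_blocks))  # 1024 again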

Memory Savings:

Naive KV Cache:
├─ Allocate max_seq_len upfront (4096 tokens)
├─ Actual sequence: 237 tokens
├─ Wasted: 3859 tokens (94%!)
└─ With 32 batch: 30% memory wasted on average

PagedAttention:
├─ Allocate only needed blocks
├─ Actual sequence: 237 tokens → 15 blocks
├─ Wasted: <16 tokens per request (<1%)
└─ 5-7x better memory utilization

Optimization 2: Multi-Query Attention (MQA)

Share K, V across attention heads:

multi_query_attention.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiQueryAttention(nn.Module):
    """
    Use the same K, V for all heads (only Q differs)

    Standard: each head has its own Q, K, V
    MQA: each head has its own Q; K and V are shared
    """
    def __init__(self, d_model=512, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Separate Q for each head
        self.W_q = nn.Linear(d_model, d_model)

        # Shared K, V (much smaller!)
        self.W_k = nn.Linear(d_model, self.d_k)
        self.W_v = nn.Linear(d_model, self.d_k)

    def forward(self, x):
        batch, seq_len, d_model = x.shape

        # Q: separate for each head -> [batch, heads, seq, d_k]
        Q = self.W_q(x).reshape(batch, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        # K, V: one shared head -> [batch, 1, seq, d_k]
        K = self.W_k(x).unsqueeze(1)
        V = self.W_v(x).unsqueeze(1)

        # Broadcasting over the head dimension shares K, V across all query heads
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_k)  # [batch, heads, seq, seq]
        weights = F.softmax(scores, dim=-1)
        output = weights @ V  # [batch, heads, seq, d_k]

        return output.transpose(1, 2).reshape(batch, seq_len, d_model)

# KV cache reduction:
# Standard (64 heads): 2 * 64 * seq_len * d_k
# MQA (64 heads): 2 * 1 * seq_len * d_k
# 64x smaller KV cache!

# Used in: PaLM, Falcon
# Quality: ~1-2% worse than standard attention
# Speed: Much faster generation

Optimization 3: Grouped-Query Attention (GQA)

Compromise between standard and MQA:

grouped_query_attention.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class GroupedQueryAttention(nn.Module):
    """
    Group heads, share K/V within groups

    Example with 8 heads, 2 KV groups:
    - Group 1 (heads 0-3): shared K, V
    - Group 2 (heads 4-7): shared K, V
    """
    def __init__(self, d_model=512, num_heads=8, num_kv_heads=2):
        super().__init__()
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.d_k = d_model // num_heads

        # Q for all heads
        self.W_q = nn.Linear(d_model, d_model)

        # K, V for the KV heads only
        self.W_k = nn.Linear(d_model, num_kv_heads * self.d_k)
        self.W_v = nn.Linear(d_model, num_kv_heads * self.d_k)

    def forward(self, x):
        batch, seq_len, d_model = x.shape

        # Q: all heads -> [batch, heads, seq, d_k]
        Q = self.W_q(x).reshape(batch, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        # K, V: only the KV heads -> [batch, kv_heads, seq, d_k]
        K = self.W_k(x).reshape(batch, seq_len, self.num_kv_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).reshape(batch, seq_len, self.num_kv_heads, self.d_k).transpose(1, 2)

        # Repeat K, V so each group of query heads attends to its group's K/V
        heads_per_group = self.num_heads // self.num_kv_heads
        K = K.repeat_interleave(heads_per_group, dim=1)
        V = V.repeat_interleave(heads_per_group, dim=1)

        # Standard attention
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_k)
        weights = F.softmax(scores, dim=-1)
        output = weights @ V

        return output.transpose(1, 2).reshape(batch, seq_len, d_model)

# LLaMA 2 uses GQA:
# - 64 query heads
# - 8 KV heads
# - 8x smaller KV cache vs standard
# - Better quality than MQA
# - Nearly as fast

# GQA is becoming the standard (LLaMA 2, Mistral, Mixtral)

Comparison:

  • Standard Attention: 64 KV heads, best quality, largest cache
  • GQA (LLaMA 2): 8 KV heads, 8× smaller cache, ~1% quality drop
  • MQA (Falcon): 1 KV head, 64× smaller cache, ~2% quality drop
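Those cache sizes follow directly from the memory formula. A quick sketch reusing the calculate_kv_cache_size() helper defined earlier (the file name is made up; dimensions are the LLaMA-2-70B-style numbers used throughout):

kv_heads_comparison.py
# KV cache at 4K context for the three attention variants: only the KV heads
# are cached, so we pass the number of KV heads as num_heads.
for name, kv_heads in [("Standard (MHA)", 64), ("GQA (LLaMA 2)", 8), ("MQA", 1)]:
    gb = calculate_kv_cache_size(
        batch_size=1, seq_len=4096, num_layers=80,
        num_heads=kv_heads, head_dim=128, precision='fp16',
    )
    print(f"{name:<16} {gb:6.2f} GB")

# Standard (MHA)    10.74 GB
# GQA (LLaMA 2)      1.34 GB   (8x smaller)
# MQA                0.17 GB   (64x smaller)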

Optimization 4: Quantized KV Cache

Store cache in INT8 instead of FP16:

quantized_kv_cache.py
import torch

class QuantizedKVCache:
  """
  Store KV cache in INT8 (1 byte) instead of FP16 (2 bytes)
  50% memory savings
  """
  def __init__(self):
      self.k_cache_int8 = []
      self.v_cache_int8 = []
      self.k_scales = []
      self.v_scales = []
  
  def quantize(self, tensor):
      """FP16 → INT8"""
      # Find scale factor
      abs_max = tensor.abs().max()
      scale = abs_max / 127
      
      # Quantize
      tensor_int8 = (tensor / scale).round().clamp(-128, 127).to(torch.int8)
      
      return tensor_int8, scale
  
  def dequantize(self, tensor_int8, scale):
      """INT8 → FP16"""
      return tensor_int8.to(torch.float16) * scale
  
  def store(self, k, v):
      k_int8, k_scale = self.quantize(k)
      v_int8, v_scale = self.quantize(v)
      
      self.k_cache_int8.append(k_int8)
      self.v_cache_int8.append(v_int8)
      self.k_scales.append(k_scale)
      self.v_scales.append(v_scale)
  
  def retrieve(self):
      K_list = [
          self.dequantize(k, scale)
          for k, scale in zip(self.k_cache_int8, self.k_scales)
      ]
      V_list = [
          self.dequantize(v, scale)
          for v, scale in zip(self.v_cache_int8, self.v_scales)
      ]
      
      return torch.cat(K_list), torch.cat(V_list)

# Memory savings:
# FP16 KV cache: 10.7 GB
# INT8 KV cache: ~5.4 GB (50% reduction)

# Quality: <0.5% degradation
# Speed: Slightly slower (dequantization overhead)
# Trade-off: Usually worth it for memory savings
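A quick round-trip check of the quantizer above (an illustrative sketch; the file name and tensor shapes are made up, and the exact error depends on the value distribution):

quantized_kv_check.py
import torch

cache = QuantizedKVCache()
k = torch.randn(1, 64, 128, dtype=torch.float16)  # [batch, heads, head_dim]
v = torch.randn(1, 64, 128, dtype=torch.float16)
cache.store(k, v)

K, V = cache.retrieve()
print((K - k).abs().max())  # per-element error, typically on the order of 1e-2
print(k.element_size(), cache.k_cache_int8[0].element_size())  # 2 bytes vs 1 byte stored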

Best Practices

Production Deployment:

  1. Use PagedAttention (vLLM) for serving
  2. Use GQA in model architecture (if training)
  3. Quantize KV cache to INT8 if memory-constrained
  4. Monitor cache hit rates and eviction
  5. Set appropriate batch sizes based on memory
production_config.py
# vLLM configuration example
from vllm import LLM, SamplingParams

llm = LLM(
  model="meta-llama/Llama-2-70b-hf",
  
  # GPU configuration
  tensor_parallel_size=2,  # Split across 2 GPUs
  
  # KV cache configuration
  max_num_seqs=256,  # Max concurrent requests
  max_num_batched_tokens=8192,
  
  # Memory management
  gpu_memory_utilization=0.9,  # Use 90% of GPU memory
  swap_space=4,  # 4 GB CPU swap for overflow
  
  # KV cache quantization
  kv_cache_dtype="fp8",  # or "auto" to keep the model's dtype
)

# This configuration:
# - Handles 256 concurrent requests
# - Uses PagedAttention automatically
# - Quantizes KV cache to FP8
# - Swaps to CPU if GPU memory full
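Generation then uses the standard vLLM API, continuing from the llm object configured above (the prompt is just a placeholder):

sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=256)

outputs = llm.generate(["Explain the KV cache in one paragraph."], sampling_params)
for output in outputs:
    print(output.outputs[0].text)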

Resources & Further Reading

Memory Optimization

  • PagedAttention (vLLM) (Kwon et al., 2023)
    • Virtual memory-inspired KV cache management
    • Reduces memory waste from fragmentation
  • Multi-Query Attention (Shazeer, 2019)
    • Share K and V across attention heads
    • Reduces KV cache by 8× for 8-head attention
  • Grouped-Query Attention (Ainslie et al., 2023)
    • Balance between multi-head and multi-query
    • Used in LLaMA 2 70B

Last updated: September 2025