# Fast Generation with KV-Cache
## The Problem: Slow Autoregressive Generation

When generating text, transformers produce one token at a time. After generating each token, we feed the entire sequence back through the model to predict the next token. This means we repeatedly recompute the same values!
### Example without cache

```
# Generate "The cat sat"

# Step 1: Generate token 3
Input:   [The, cat]
Compute: K[The], V[The], K[cat], V[cat]
Output:  "sat" ✓

# Step 2: Generate token 4
Input:   [The, cat, sat]
Compute: K[The], V[The], K[cat], V[cat], K[sat], V[sat]   ← Redundant!
Output:  "on"

# Step 3: Generate token 5
Input:   [The, cat, sat, on]
Compute: K[The], V[The], K[cat], V[cat], K[sat], V[sat], K[on], V[on]   ← Redundant!
Output:  "the"
```

For generating n tokens, we process 1 + 2 + 3 + … + n = O(n²) tokens in total. Very slow!
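In code, cache-free generation is a loop like the following sketch (a hypothetical `model` that maps a `(1, seq_len)` token tensor to `(1, seq_len, vocab_size)` logits; not this project's actual generate method):

```python
# A minimal sketch (not this repo's code) of cache-free greedy decoding.
import torch

def generate_no_cache(model, tokens: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
    """`tokens` is (1, seq_len); `model` returns (1, seq_len, vocab_size) logits."""
    for _ in range(max_new_tokens):
        logits = model(tokens)                    # re-processes the ENTIRE sequence every step
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)  # greedy pick, shape (1, 1)
        tokens = torch.cat([tokens, next_token], dim=1)
    return tokens
```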
## The Solution: KV-Cache

Key Insight: In attention, K (Key) and V (Value) for past tokens never change! Only the new token's query matters. We can cache K and V from previous steps and reuse them.
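To see the insight in code, here is a minimal single-head sketch of one decode step (projections and multi-head logic omitted; the function name and shapes are illustrative, not this project's API). The cached K and V rows are reused untouched; only the new token's q, k, v are computed:

```python
import torch
import torch.nn.functional as F

def attend_one_step(q_new, K_cache, V_cache, k_new, v_new):
    """One decode step of single-head attention with a KV-cache.

    q_new, k_new, v_new: (batch, 1, d_k)  -- computed for the NEW token only
    K_cache, V_cache:    (batch, t, d_k)  -- cached from the t previous steps
    """
    K = torch.cat([K_cache, k_new], dim=1)   # (batch, t+1, d_k): old entries reused as-is
    V = torch.cat([V_cache, v_new], dim=1)
    scores = q_new @ K.transpose(-2, -1) / K.shape[-1] ** 0.5   # (batch, 1, t+1)
    out = F.softmax(scores, dim=-1) @ V                          # (batch, 1, d_k)
    return out, K, V   # hand K, V back as the cache for the next step
```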
## How It Works

Two modes:

- PREFILL: Process the initial prompt once, computing and caching K and V for every prompt token
- DECODE: For each new token, compute only that token's K and V, then concatenate them with the cached values
## Implementation

```
# PREFILL: Process prompt "The cat"
prompt = [The, cat]
K_all = [K_The, K_cat]   # Cache these!
V_all = [V_The, V_cat]   # Cache these!
Output: "sat"

# DECODE: Generate next token
new_token = [sat]
K_new = [K_sat]   # Only compute for new token
V_new = [V_sat]
K_all = concat(K_cached, K_new)   # = [K_The, K_cat, K_sat]
V_all = concat(V_cached, V_new)   # = [V_The, V_cat, V_sat]
Output: "on"

# Continue...
```
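Here is a runnable sketch of the same two modes; the `KVCache` class and the shapes are illustrative, not this project's actual implementation:

```python
import torch

class KVCache:
    """A hypothetical per-layer cache; grows along the sequence dimension."""
    def __init__(self):
        self.K = None   # (batch, seq_len_so_far, d_k)
        self.V = None

    def update(self, k_new: torch.Tensor, v_new: torch.Tensor):
        if self.K is None:     # PREFILL: first call stores K, V for the whole prompt
            self.K, self.V = k_new, v_new
        else:                  # DECODE: append the single new token's K, V
            self.K = torch.cat([self.K, k_new], dim=1)
            self.V = torch.cat([self.V, v_new], dim=1)
        return self.K, self.V

cache = KVCache()
cache.update(torch.randn(1, 2, 64), torch.randn(1, 2, 64))         # prefill: "The cat"
K, V = cache.update(torch.randn(1, 1, 64), torch.randn(1, 1, 64))  # decode: "sat"
print(K.shape)   # torch.Size([1, 3, 64])
```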
## Memory vs Speed Tradeoff
Memory Cost: For each layer, we cache K and V tensors with shape (batch, num_heads, seq_len, d_k). For a 6-layer model with d_model=256, 4 heads, and a 200-token sequence, this is only ~2.5 MB per example in float32. Very affordable!
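The arithmetic behind that figure, as a quick sanity check (float32, ignoring any framework bookkeeping):

```python
# Checking the memory figure above: float32 = 4 bytes per element.
layers, heads, seq_len, d_model, bytes_per_elem = 6, 4, 200, 256, 4
d_k = d_model // heads                                   # 64
per_layer = 2 * heads * seq_len * d_k * bytes_per_elem   # K and V together
print(f"{layers * per_layer / 1e6:.2f} MB")              # 2.46 MB
```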
Speed Benefit: For generating n tokens, the cache cuts the total number of token positions processed from O(n²) to O(n). Typical speedups (a quick calculation after this list shows why the gain grows with sequence length):

- Short sequences (10-20 tokens): 2-5x faster
- Medium sequences (50-100 tokens): 10-20x faster
- Long sequences (200+ tokens): 20-50x faster
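Counting positions makes the trend clear; measured wall-clock speedups stay below these raw ratios, since attention over the cached keys still grows with length and other per-step overheads remain:

```python
# Token positions processed: n*(n+1)/2 without a cache vs n with one.
for n in (20, 100, 200):
    no_cache, with_cache = n * (n + 1) // 2, n
    print(f"n={n:>3}: {no_cache:>5} vs {with_cache:>3} positions "
          f"(~{no_cache / with_cache:.0f}x fewer)")
# n= 20:   210 vs  20 positions (~10x fewer)
# n=100:  5050 vs 100 positions (~50x fewer)
# n=200: 20100 vs 200 positions (~100x fewer)
```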
## Using KV-Cache

```python
# KV-cache is enabled by default!
generated = model.generate(
    start_tokens,
    max_length=100,
    sampling_strategy="greedy",
    use_cache=True,   # ← Default!
)

# Disable the cache (for debugging/comparison)
generated = model.generate(
    start_tokens,
    max_length=100,
    use_cache=False,  # ← Much slower!
)
```
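As a rough sanity check (assuming the `model` and `start_tokens` from the snippet above are in scope), you can time the two modes directly; the benchmark script below is the more thorough way to measure:

```python
# Rough timing sketch using the same generate() call as above.
import time

def timed(use_cache: bool) -> float:
    start = time.perf_counter()
    model.generate(start_tokens, max_length=100, use_cache=use_cache)
    return time.perf_counter() - start

print(f"speedup: {timed(False) / timed(True):.1f}x")
```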
Benchmark the speedup yourself:

```sh
python commands/benchmark_generation.py
```
## Full Code

See the full implementation: