
Fast Generation with KV-Cache

The Problem: Slow Autoregressive Generation

When generating text, transformers produce one token at a time. After generating each token, we feed the entire sequence back through the model to predict the next token. This means we repeatedly recompute the same values!

# Generate "The cat sat"
# Step 1: Generate token 3
Input: [The, cat]
Compute: K[The], V[The], K[cat], V[cat]
Output: "sat"
# Step 2: Generate token 4
Input: [The, cat, sat]
Compute: K[The], V[The], K[cat], V[cat], K[sat], V[sat] ← Redundant!
Output: "on"
# Step 3: Generate token 5
Input: [The, cat, sat, on]
Compute: K[The], V[The], K[cat], V[cat], K[sat], V[sat], K[on], V[on] ← Redundant!
Output: "the"

For generating n tokens, we process 1 + 2 + 3 + … + n = O(n²) tokens total. Very slow!
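To make the redundancy concrete, here is a minimal sketch of cache-free greedy generation (the model call and argmax sampling are illustrative stand-ins, not the library's exact API): every step re-runs attention over the full prefix.

import torch

def generate_naive(model, tokens, max_new_tokens):
    # tokens: (batch, seq_len) token IDs
    for _ in range(max_new_tokens):
        # Re-run the FULL sequence every step: step t processes t tokens,
        # so n steps cost 1 + 2 + ... + n = O(n^2) tokens in total.
        logits = model(tokens)                                    # (batch, seq_len, vocab)
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)   # greedy pick
        tokens = torch.cat([tokens, next_token], dim=1)
    return tokens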

Key Insight: In attention, K (Key) and V (Value) for past tokens never change! Only the new token’s query matters. We can cache K and V from previous steps and reuse them.

[Figure: KV-cache speedup comparison]
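A single-head sketch of what the cache buys us, assuming hypothetical projection matrices W_q, W_k, W_v and a batch of one: the new token contributes one query, key, and value; the cached K and V are only ever appended to.

import torch
import torch.nn.functional as F

def cached_attention_step(x_new, W_q, W_k, W_v, K_cache, V_cache):
    # x_new: (1, d_model) embedding of the newest token only
    q = x_new @ W_q                      # (1, d_k) -- only the new query matters
    k = x_new @ W_k                      # (1, d_k)
    v = x_new @ W_v                      # (1, d_k)
    K = torch.cat([K_cache, k], dim=0)   # (t, d_k) -- past keys never change, just append
    V = torch.cat([V_cache, v], dim=0)   # (t, d_k)
    scores = (q @ K.T) / K.shape[-1] ** 0.5   # (1, t) attention scores for the new query
    out = F.softmax(scores, dim=-1) @ V       # (1, d_k) attention output
    return out, K, V                     # hand the grown cache to the next step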

Two Modes:

  • PREFILL: Process initial prompt, compute and cache K, V for all tokens
  • DECODE: For each new token, compute only its K, V, concatenate with cached values
# PREFILL: Process prompt "The cat"
prompt = [The, cat]
K_all = [K_The, K_cat] # Cache these!
V_all = [V_The, V_cat] # Cache these!
Output: "sat"

# DECODE: Generate next token
new_token = [sat]
K_new = [K_sat] # Only compute for new token
V_new = [V_sat]
K_all = concat(K_cached, K_new) # = [K_The, K_cat, K_sat]
V_all = concat(V_cached, V_new) # = [V_The, V_cat, V_sat]
Output: "on"

# Continue...
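The same two phases as a generation loop, sketched under the assumption that the model's forward pass accepts and returns a per-layer cache (the past_kv argument is hypothetical; see the linked implementation at the end for the real interface):

import torch

def generate_with_cache(model, prompt, max_new_tokens):
    # PREFILL: one pass over the whole prompt builds K/V caches for every layer.
    logits, past_kv = model(prompt, past_kv=None)
    next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
    tokens = torch.cat([prompt, next_token], dim=1)

    # DECODE: each step feeds ONLY the newest token; cached K/V cover the rest.
    for _ in range(max_new_tokens - 1):
        logits, past_kv = model(next_token, past_kv=past_kv)
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_token], dim=1)
    return tokens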

Memory Cost: For each layer, we cache K and V tensors with shape (batch, num_heads, seq_len, d_k). For a 6-layer model with d_model=256, 4 heads, and a 200-token sequence, that works out to roughly 2-3 MB per example at float32. Very affordable!
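A quick back-of-the-envelope check of that figure, assuming float32 (4 bytes per element; halve it for float16):

layers, heads, seq_len, d_k = 6, 4, 200, 256 // 4   # d_k = d_model / num_heads = 64
bytes_per_elem = 4                                   # float32

# K and V per layer, per example: (num_heads, seq_len, d_k) each
cache_bytes = layers * 2 * heads * seq_len * d_k * bytes_per_elem
print(f"{cache_bytes / 1e6:.1f} MB per example")     # ≈ 2.5 MB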

Speed Benefit: Reduces time complexity from O(n²) to O(n) for generating n tokens. Typical speedups:

  • Short sequences (10-20 tokens): 2-5x faster
  • Medium sequences (50-100 tokens): 10-20x faster
  • Long sequences (200+ tokens): 20-50x faster
# KV-cache is enabled by default!
generated = model.generate(
    start_tokens,
    max_length=100,
    sampling_strategy="greedy",
    use_cache=True,  # ← Default!
)

# Disable cache (for debugging/comparison)
generated = model.generate(
    start_tokens,
    max_length=100,
    use_cache=False,  # ← Much slower!
)
# Benchmark the speedup yourself
python commands/benchmark_generation.py
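Or time it inline yourself; this rough comparison reuses the generate() calls shown above, and the exact numbers will vary with hardware and sequence length:

import time

def time_generate(use_cache):
    start = time.perf_counter()
    model.generate(start_tokens, max_length=100, use_cache=use_cache)
    return time.perf_counter() - start

t_cached, t_naive = time_generate(True), time_generate(False)
print(f"with cache: {t_cached:.2f}s  without: {t_naive:.2f}s  "
      f"speedup: {t_naive / t_cached:.1f}x")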

See the full implementation: