We’ve trained a transformer. Now we want to use it—to generate text, token by token. But there’s a problem: the naive approach is horrifically slow.
This notebook explains the problem and its elegant solution: KV-caching. This single optimization takes the cost of generating $n$ tokens from $O(n^2)$ to $O(n)$, making it practical for real-world use.
The Problem: Redundant Computation¶
Transformers generate text autoregressively—one token at a time. Each new token depends on all previous tokens:
Step 1: [The] → predict "cat"
Step 2: [The, cat] → predict "sat"
Step 3: [The, cat, sat] → predict "on"
Step 4: [The, cat, sat, on] → predict "the"
...

The Wasteful Part¶
At each step, we need the model to attend to all previous tokens. Without optimization, this means:
| Step | Tokens Processed | K, V Computed For |
|---|---|---|
| 1 | “The” | “The” |
| 2 | “The cat” | “The” (again!), “cat” |
| 3 | “The cat sat” | “The” (again!), “cat” (again!), “sat” |
| 4 | “The cat sat on” | “The” (again!), “cat” (again!), “sat” (again!), “on” |
We keep recomputing K and V for tokens we’ve already processed! To generate $n$ tokens, we process $1 + 2 + \cdots + n = \frac{n(n+1)}{2}$ token computations.
Time complexity: $O(n^2)$
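To make the waste concrete, here is a quick back-of-the-envelope count in plain Python (the prompt is ignored for simplicity):

# Token-level K, V computations for the naive approach, n = 100 generated tokens
n = 100
total = sum(range(1, n + 1))    # 1 + 2 + ... + n
print(total)                    # 5050 -- versus just 100 if each token were processed once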
The Key Insight¶
Look back at the attention formula:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

For autoregressive generation, K and V for past tokens never change. Once we’ve computed $K_i$ and $V_i$ for token $i$, those values are fixed forever. Only the new token needs fresh computation.
The solution: cache K and V from previous steps and reuse them.
KV-Cache: The Solution¶
The KV-cache stores computed K and V tensors across generation steps:
Phase 1: Prefill (process the initial prompt)
Input: “The cat” (2 tokens)
Compute: $K_1, V_1$ for “The”; $K_2, V_2$ for “cat”
Cache: $[K_1, K_2]$, $[V_1, V_2]$
Phase 2: Decode (generate one token at a time)
Input: just the NEW token (“sat”)
Compute: $K_3, V_3$ for “sat” only
Retrieve: $[K_1, K_2]$ and $[V_1, V_2]$ from the cache
Concatenate: $K = [K_1, K_2, K_3]$, $V = [V_1, V_2, V_3]$
Attention: New token’s Q attends to all K, V
Now we process exactly 1 token per step. Time complexity: $O(n)$ for $n$ generated tokens.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import time

Implementation¶
Let’s build attention with KV-cache support. The key change: the forward method accepts an optional cache and returns an updated cache.
class CachedMultiHeadAttention(nn.Module):
"""
Multi-head attention with KV-cache for efficient generation.
During generation:
- First call (prefill): Process full prompt, initialize cache
- Subsequent calls (decode): Process single token, update cache
"""
def __init__(self, d_model: int, num_heads: int):
super().__init__()
assert d_model % num_heads == 0
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
# Projections
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
def forward(self, x, kv_cache=None):
"""
Args:
x: (batch, seq_len, d_model) - tokens to process
During prefill: full prompt
During decode: single new token (seq_len=1)
kv_cache: tuple of (K_cached, V_cached) or None
Returns:
output: (batch, seq_len, d_model)
new_cache: tuple of (K_updated, V_updated)
"""
batch_size, seq_len, _ = x.shape
# Project to Q, K, V
# Shape: (batch, seq_len, d_model)
Q = self.W_q(x)
K_new = self.W_k(x)
V_new = self.W_v(x)
# Reshape for multi-head: (batch, num_heads, seq_len, d_k)
Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
K_new = K_new.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
V_new = V_new.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
# === THE CACHING MAGIC ===
if kv_cache is not None:
# Append new K, V to cached values
K_cached, V_cached = kv_cache
K = torch.cat([K_cached, K_new], dim=2) # Concat along seq dimension
V = torch.cat([V_cached, V_new], dim=2)
else:
# First call: no cache yet
K = K_new
V = V_new
# Standard scaled dot-product attention
# Q: (batch, heads, seq_q, d_k) - could be 1 during decode
# K: (batch, heads, seq_kv, d_k) - grows each step
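        # NOTE: no causal mask is applied here, to keep the example minimal.
        # See the "Causal Masking" note under Important Implementation Details below.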
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
attn_weights = F.softmax(scores, dim=-1)
context = torch.matmul(attn_weights, V)
# Reshape and project output
context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
output = self.W_o(context)
# Return output AND updated cache
        return output, (K, V)

The Critical Lines¶
The entire optimization happens here:
if kv_cache is not None:
    K_cached, V_cached = kv_cache
    K = torch.cat([K_cached, K_new], dim=2)
    V = torch.cat([V_cached, V_new], dim=2)

Instead of recomputing K and V for all previous tokens, we just concatenate the cached values with the new token’s K and V. The attention mechanism then works exactly as before—it doesn’t know or care that we assembled K and V from cached parts.
Demonstration¶
Let’s trace through how the cache grows during generation:
# Setup
d_model = 64
num_heads = 4
attn = CachedMultiHeadAttention(d_model, num_heads)
print("=" * 60)
print("PHASE 1: PREFILL (process initial prompt)")
print("=" * 60)
# Simulate a 5-token prompt
prompt = torch.randn(1, 5, d_model) # batch=1, seq=5
output, kv_cache = attn(prompt, kv_cache=None)
print(f"Input shape: {prompt.shape} (batch=1, seq=5, d_model={d_model})")
print(f"Output shape: {output.shape}")
print(f"\nCache initialized:")
print(f" K cache: {kv_cache[0].shape} (batch=1, heads={num_heads}, seq=5, d_k={d_model//num_heads})")
print(f" V cache: {kv_cache[1].shape}")============================================================
PHASE 1: PREFILL (process initial prompt)
============================================================
Input shape: torch.Size([1, 5, 64]) (batch=1, seq=5, d_model=64)
Output shape: torch.Size([1, 5, 64])
Cache initialized:
K cache: torch.Size([1, 4, 5, 16]) (batch=1, heads=4, seq=5, d_k=16)
V cache: torch.Size([1, 4, 5, 16])
print("\n" + "=" * 60)
print("PHASE 2: DECODE (generate tokens one at a time)")
print("=" * 60)
for step in range(1, 4):
# Simulate generating one new token
# In practice, this would be the embedding of the predicted token
new_token = torch.randn(1, 1, d_model) # seq_len = 1!
# Process ONLY the new token, using cached K, V
output, kv_cache = attn(new_token, kv_cache=kv_cache)
print(f"\nStep {step}:")
print(f" Input: {new_token.shape} (just 1 new token!)")
print(f" Output: {output.shape}")
print(f" K cache: {kv_cache[0].shape} (seq grew to {kv_cache[0].shape[2]})")
print("\n" + "=" * 60)
print(f"After 3 decode steps: cache has {kv_cache[0].shape[2]} positions")
print(f"Started with 5 (prompt) + generated 3 = 8 total")
print("=" * 60)
============================================================
PHASE 2: DECODE (generate tokens one at a time)
============================================================
Step 1:
Input: torch.Size([1, 1, 64]) (just 1 new token!)
Output: torch.Size([1, 1, 64])
K cache: torch.Size([1, 4, 6, 16]) (seq grew to 6)
Step 2:
Input: torch.Size([1, 1, 64]) (just 1 new token!)
Output: torch.Size([1, 1, 64])
K cache: torch.Size([1, 4, 7, 16]) (seq grew to 7)
Step 3:
Input: torch.Size([1, 1, 64]) (just 1 new token!)
Output: torch.Size([1, 1, 64])
K cache: torch.Size([1, 4, 8, 16]) (seq grew to 8)
============================================================
After 3 decode steps: cache has 8 positions
Started with 5 (prompt) + generated 3 = 8 total
============================================================
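Before analyzing the complexity, here is how this two-phase pattern slots into an actual generation loop. This is a minimal sketch, not code from earlier in the notebook: `vocab_size`, the `embed` table, and the `lm_head` projection are hypothetical stand-ins for a full model, and the single `attn` layer from the demonstration above plays the role of a whole transformer stack.

# Minimal greedy-generation sketch (hypothetical embed/lm_head; reuses `attn` and d_model=64 from above)
vocab_size = 1000
embed = nn.Embedding(vocab_size, d_model)     # token id -> d_model vector
lm_head = nn.Linear(d_model, vocab_size)      # d_model vector -> next-token logits

prompt_ids = torch.randint(0, vocab_size, (1, 5))        # batch=1, 5 prompt tokens

# Prefill: run the whole prompt once and initialize the cache
out, cache = attn(embed(prompt_ids), kv_cache=None)
next_id = lm_head(out[:, -1:, :]).argmax(dim=-1)         # greedy pick from the last position

# Decode: feed back one token at a time, reusing the cache
generated = [next_id]
for _ in range(3):
    out, cache = attn(embed(next_id), kv_cache=cache)    # out: (1, 1, d_model)
    next_id = lm_head(out).argmax(dim=-1)                # (1, 1)
    generated.append(next_id)

print(torch.cat(generated, dim=1))                       # 4 generated token ids

With untrained weights the generated ids are of course meaningless; the point is the call pattern: one prefill call, then exactly one decode call per generated token.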
Complexity Analysis¶
Let’s quantify the improvement. Consider generating $n$ tokens after a prompt of length $p$:
Without KV-Cache¶
| Step | Tokens Processed | K, V Computed |
|---|---|---|
| 1 | p + 1 | p + 1 |
| 2 | p + 2 | p + 2 |
| ... | ... | ... |
| n | p + n | p + n |
Total K, V computations: $\sum_{i=1}^{n} (p + i) = np + \frac{n(n+1)}{2} = O(np + n^2)$
With KV-Cache¶
| Step | Tokens Processed | K, V Computed |
|---|---|---|
| Prefill | p | p |
| 1 | 1 | 1 |
| 2 | 1 | 1 |
| ... | ... | ... |
| n | 1 | 1 |
Total K, V computations: $p + n = O(p + n)$
Speedup factor: $\frac{np + n(n+1)/2}{p + n} \approx \frac{n}{2}$ for $n \gg p$
For generating 100 tokens, this is ~50× fewer computations!
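Plugging concrete numbers into the formulas above (the values $p = 10$ and $n = 100$ are chosen purely for illustration):

# Concrete speedup estimate: p = 10 prompt tokens, n = 100 generated tokens
p, n = 10, 100
without_cache = sum(p + i for i in range(1, n + 1))   # reprocess the full sequence every step
with_cache = p + n                                    # each position's K, V computed exactly once
print(without_cache, with_cache, round(without_cache / with_cache, 1))   # 6050, 110, 55.0 (~n/2)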
Memory Cost¶
KV-caching trades memory for speed. For each layer, we store K and V tensors of shape (batch, num_heads, seq_len, d_k), i.e.

$$2 \times \text{batch} \times \text{num\_heads} \times \text{seq\_len} \times d_k \ \text{values per layer}$$

For our example model (6 layers, 4 heads, d_model=256, seq_len=512) at float32 and batch size 1:

$$6 \times 2 \times 512 \times 256 \times 4 \ \text{bytes} \approx 6 \ \text{MB}$$
This is tiny compared to model weights (~100MB for a small model). The tradeoff is overwhelmingly worth it!
| Component | Memory |
|---|---|
| Model weights | ~100 MB |
| KV cache (512 tokens) | ~6 MB |
| Speedup | ~50× |
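As a quick sanity check of the ~6 MB figure (assuming float32 values, batch size 1, and num_heads × d_k = d_model):

# KV-cache size for the example model above
layers, d_model_ex, seq_len = 6, 256, 512
bytes_per_value = 4                                                  # float32
cache_bytes = layers * 2 * seq_len * d_model_ex * bytes_per_value    # 2 = K and V
print(f"{cache_bytes / 1e6:.1f} MB")                                 # about 6.3 MB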
Benchmark¶
Let’s measure the actual speedup:
# Attention WITHOUT caching (for comparison)
class NoCacheAttention(nn.Module):
"""Standard attention that recomputes everything each time."""
def __init__(self, d_model, num_heads):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
def forward(self, x):
batch_size, seq_len, _ = x.shape
Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
attn = F.softmax(scores, dim=-1)
context = torch.matmul(attn, V)
context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.W_o(context)

# Setup benchmark
d_model = 256
num_heads = 8
prompt_len = 10
generate_tokens = 50
attn_no_cache = NoCacheAttention(d_model, num_heads)
attn_with_cache = CachedMultiHeadAttention(d_model, num_heads)
# Share weights for fair comparison
attn_with_cache.W_q = attn_no_cache.W_q
attn_with_cache.W_k = attn_no_cache.W_k
attn_with_cache.W_v = attn_no_cache.W_v
attn_with_cache.W_o = attn_no_cache.W_o
print(f"Benchmark: Generate {generate_tokens} tokens after {prompt_len}-token prompt")
print(f"Model: d_model={d_model}, num_heads={num_heads}")
print("=" * 55)Benchmark: Generate 50 tokens after 10-token prompt
Model: d_model=256, num_heads=8
=======================================================
# WITHOUT cache: must reprocess entire sequence each step
prompt = torch.randn(1, prompt_len, d_model)
sequence = prompt.clone()
start = time.time()
for _ in range(generate_tokens):
# Process ENTIRE sequence each time
output = attn_no_cache(sequence)
# Take last position's output as "next token" (simplified)
new_token = output[:, -1:, :]
sequence = torch.cat([sequence, new_token], dim=1)
time_no_cache = time.time() - start
print(f"\nWITHOUT KV-Cache:")
print(f" Final sequence length: {sequence.shape[1]}")
print(f" Time: {time_no_cache*1000:.1f} ms")
WITHOUT KV-Cache:
Final sequence length: 60
Time: 11.6 ms
# WITH cache: process only new tokens
prompt = torch.randn(1, prompt_len, d_model)
start = time.time()
# Prefill: process entire prompt
output, kv_cache = attn_with_cache(prompt, kv_cache=None)
last_output = output[:, -1:, :]
# Decode: process one token at a time
for _ in range(generate_tokens):
output, kv_cache = attn_with_cache(last_output, kv_cache=kv_cache)
last_output = output # Already shape (1, 1, d_model)
time_with_cache = time.time() - start
print(f"\nWITH KV-Cache:")
print(f" Final cache length: {kv_cache[0].shape[2]}")
print(f" Time: {time_with_cache*1000:.1f} ms")
print(f"\n{'='*55}")
print(f"Speedup: {time_no_cache/time_with_cache:.1f}x faster!")
print(f"{'='*55}")
WITH KV-Cache:
Final cache length: 60
Time: 6.2 ms
=======================================================
Speedup: 1.9x faster!
=======================================================
Important Implementation Details¶
1. Position Encodings¶
When using learned or sinusoidal position encodings, the new token must receive the correct position index. If you’ve generated tokens 0-9 and are now generating token 10, it must get position embedding 10, not 0.
# Track position during generation
current_pos = prompt_length
for _ in range(num_tokens):
pos_embed = get_position_embedding(current_pos)
...
    current_pos += 1

2. Causal Masking¶
During prefill, you still need causal masking so position $i$ only attends to positions $j \le i$. During decode, no mask is needed—each new token naturally attends only to past tokens (the cache) and itself.
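Here is one way the prefill mask could be added to CachedMultiHeadAttention above. This is a sketch, not part of the class as written, and the helper name causal_mask_scores is ours:

def causal_mask_scores(scores, seq_q, seq_kv):
    """Mask future positions in the attention scores.

    scores has shape (batch, heads, seq_q, seq_kv); the seq_q query positions
    are the last seq_q positions of the seq_kv-long key/value sequence."""
    if seq_q == 1:
        return scores  # decode step: the single query may see every cached position
    mask = torch.triu(
        torch.ones(seq_q, seq_kv, dtype=torch.bool, device=scores.device),
        diagonal=seq_kv - seq_q + 1,
    )  # True above the causal diagonal
    return scores.masked_fill(mask, float("-inf"))

Inside forward, this would be applied as scores = causal_mask_scores(scores, Q.size(2), K.size(2)) just before the softmax. For a fresh prefill (seq_q == seq_kv) it reduces to the familiar upper-triangular mask; during decode it is a no-op.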
3. Multi-Layer Caching¶
A full transformer has multiple layers, each with its own KV cache:
# Cache is a list of (K, V) tuples, one per layer
kv_caches = [None] * num_layers
for i, layer in enumerate(layers):
    x, kv_caches[i] = layer(x, kv_cache=kv_caches[i])

Why Every Production System Uses This¶
KV-caching isn’t optional for practical LLM deployment. Consider:
| Scenario | Without Cache | With Cache |
|---|---|---|
| Generate 100 tokens | 5,050 computations | 100 computations |
| Generate 1000 tokens | 500,500 computations | 1,000 computations |
| Generate 4096 tokens | 8,390,656 computations | 4,096 computations |
At 4096 tokens (a common context length), caching provides a 2048× reduction in K, V computations. The memory cost of the cache (roughly 100 MB or more for a large model, depending on context length) is negligible compared to this speedup.
GPT, Claude, Gemini—every production system uses KV-caching. It’s not merely an optimization; it’s a requirement for usable latency.
Summary¶
| Aspect | Without Cache | With Cache |
|---|---|---|
| K, V computation | All tokens every step | 1 token per step |
| Time complexity | $O(n^2)$ | $O(n)$ |
| Memory | Minimal | Cache grows with context |
| Practical speedup | Baseline | 10-1000× faster |
Key insight: K and V for past tokens are deterministic—compute once, reuse forever.
The tradeoff: Memory for speed. Almost always worth it.
Next: Interpretability¶
We’ve built, trained, and optimized a transformer. But what has it actually learned? The final notebook explores mechanistic interpretability—techniques for peeking inside the black box to understand what patterns and circuits the model has discovered.