KV-Cache

We’ve trained a transformer. Now we want to use it—to generate text, token by token. But there’s a problem: the naive approach is horrifically slow.

This notebook explains the problem and its elegant solution: KV-caching. This single optimization takes generation from $O(n^2)$ to $O(n)$, making it practical for real-world use.

The Problem: Redundant Computation

Transformers generate text autoregressively—one token at a time. Each new token depends on all previous tokens:

Step 1: [The] → predict "cat"
Step 2: [The, cat] → predict "sat"
Step 3: [The, cat, sat] → predict "on"
Step 4: [The, cat, sat, on] → predict "the"
...

The Wasteful Part

At each step, we need the model to attend to all previous tokens. Without optimization, this means:

| Step | Tokens Processed | K, V Computed For |
|------|------------------|-------------------|
| 1 | “The” | “The” |
| 2 | “The cat” | “The” (again!), “cat” |
| 3 | “The cat sat” | “The” (again!), “cat” (again!), “sat” |
| 4 | “The cat sat on” | “The” (again!), “cat” (again!), “sat” (again!), “on” |

We keep recomputing K and V for tokens we’ve already processed! To generate $n$ tokens, we process $1 + 2 + 3 + \dots + n = \frac{n(n+1)}{2}$ token computations.

Time complexity: $O(n^2)$
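
A quick sanity check of that count, as a short snippet (illustrative, not part of the benchmark later in this notebook):

n = 100
naive_kv_computations = sum(range(1, n + 1))  # 1 + 2 + ... + n: reprocess the whole prefix each step
print(naive_kv_computations)  # 5050 -- versus just 100 new K, V computations with a cache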

The Key Insight

Look back at the attention formula:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

For autoregressive generation, K and V for past tokens never change. Once we’ve computed $K_{\text{The}}$ and $V_{\text{The}}$, those values are fixed forever. Only the new token needs fresh computation.

The solution: cache K and V from previous steps and reuse them.

KV-Cache: The Solution

The KV-cache stores computed K and V tensors across generation steps:

Phase 1: Prefill (process the initial prompt)

  • Input: “The cat” (2 tokens)

  • Compute: $K_0, V_0$ for “The”; $K_1, V_1$ for “cat”

  • Cache: $[K_0, K_1]$, $[V_0, V_1]$

Phase 2: Decode (generate one token at a time)

  • Input: just the NEW token (“sat”)

  • Compute: $K_2, V_2$ for “sat” only

  • Retrieve: $K_{\text{cached}} = [K_0, K_1]$

  • Concatenate: $K_{\text{full}} = [K_0, K_1, K_2]$

  • Attention: New token’s Q attends to all K, V

Now we process exactly 1 token per step. Time complexity: $O(n)$

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import time

Implementation

Let’s build attention with KV-cache support. The key change: the forward method accepts an optional cache and returns an updated cache.

class CachedMultiHeadAttention(nn.Module):
    """
    Multi-head attention with KV-cache for efficient generation.
    
    During generation:
    - First call (prefill): Process full prompt, initialize cache
    - Subsequent calls (decode): Process single token, update cache
    """
    
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        # Projections
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
    
    def forward(self, x, kv_cache=None):
        """
        Args:
            x: (batch, seq_len, d_model) - tokens to process
               During prefill: full prompt
               During decode: single new token (seq_len=1)
            kv_cache: tuple of (K_cached, V_cached) or None
        
        Returns:
            output: (batch, seq_len, d_model)
            new_cache: tuple of (K_updated, V_updated)
        """
        batch_size, seq_len, _ = x.shape
        
        # Project to Q, K, V
        # Shape: (batch, seq_len, d_model)
        Q = self.W_q(x)
        K_new = self.W_k(x)
        V_new = self.W_v(x)
        
        # Reshape for multi-head: (batch, num_heads, seq_len, d_k)
        Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K_new = K_new.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V_new = V_new.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        
        # === THE CACHING MAGIC ===
        if kv_cache is not None:
            # Append new K, V to cached values
            K_cached, V_cached = kv_cache
            K = torch.cat([K_cached, K_new], dim=2)  # Concat along seq dimension
            V = torch.cat([V_cached, V_new], dim=2)
        else:
            # First call: no cache yet
            K = K_new
            V = V_new
        
        # Standard scaled dot-product attention
        # Q: (batch, heads, seq_q, d_k) - could be 1 during decode
        # K: (batch, heads, seq_kv, d_k) - grows each step
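        # Note: no causal mask is applied here (kept simple for this demo);
        # see "Causal Masking" under Important Implementation Details below.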
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn_weights = F.softmax(scores, dim=-1)
        context = torch.matmul(attn_weights, V)
        
        # Reshape and project output
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        output = self.W_o(context)
        
        # Return output AND updated cache
        return output, (K, V)

The Critical Lines

The entire optimization happens here:

if kv_cache is not None:
    K = torch.cat([K_cached, K_new], dim=2)
    V = torch.cat([V_cached, V_new], dim=2)

Instead of recomputing K and V for all previous tokens, we just concatenate the cached values with the new token’s K and V. The attention mechanism then works exactly as before—it doesn’t know or care that we assembled K and V from cached parts.
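
To see the concatenation in isolation, here is a tiny standalone example with dummy tensors in the same (batch, heads, seq, d_k) layout used above:

import torch

# A cache holding 5 positions, plus K for one new token: (batch=1, heads=4, seq, d_k=16)
K_cached = torch.zeros(1, 4, 5, 16)
K_new = torch.zeros(1, 4, 1, 16)

K = torch.cat([K_cached, K_new], dim=2)  # concatenate along the sequence dimension
print(K.shape)  # torch.Size([1, 4, 6, 16]) -- one position longer than before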

Demonstration

Let’s trace through how the cache grows during generation:

# Setup
d_model = 64
num_heads = 4
attn = CachedMultiHeadAttention(d_model, num_heads)

print("=" * 60)
print("PHASE 1: PREFILL (process initial prompt)")
print("=" * 60)

# Simulate a 5-token prompt
prompt = torch.randn(1, 5, d_model)  # batch=1, seq=5
output, kv_cache = attn(prompt, kv_cache=None)

print(f"Input shape:  {prompt.shape}  (batch=1, seq=5, d_model={d_model})")
print(f"Output shape: {output.shape}")
print(f"\nCache initialized:")
print(f"  K cache: {kv_cache[0].shape}  (batch=1, heads={num_heads}, seq=5, d_k={d_model//num_heads})")
print(f"  V cache: {kv_cache[1].shape}")
============================================================
PHASE 1: PREFILL (process initial prompt)
============================================================
Input shape:  torch.Size([1, 5, 64])  (batch=1, seq=5, d_model=64)
Output shape: torch.Size([1, 5, 64])

Cache initialized:
  K cache: torch.Size([1, 4, 5, 16])  (batch=1, heads=4, seq=5, d_k=16)
  V cache: torch.Size([1, 4, 5, 16])
print("\n" + "=" * 60)
print("PHASE 2: DECODE (generate tokens one at a time)")
print("=" * 60)

for step in range(1, 4):
    # Simulate generating one new token
    # In practice, this would be the embedding of the predicted token
    new_token = torch.randn(1, 1, d_model)  # seq_len = 1!
    
    # Process ONLY the new token, using cached K, V
    output, kv_cache = attn(new_token, kv_cache=kv_cache)
    
    print(f"\nStep {step}:")
    print(f"  Input: {new_token.shape}  (just 1 new token!)")
    print(f"  Output: {output.shape}")
    print(f"  K cache: {kv_cache[0].shape}  (seq grew to {kv_cache[0].shape[2]})")

print("\n" + "=" * 60)
print(f"After 3 decode steps: cache has {kv_cache[0].shape[2]} positions")
print(f"Started with 5 (prompt) + generated 3 = 8 total")
print("=" * 60)

============================================================
PHASE 2: DECODE (generate tokens one at a time)
============================================================

Step 1:
  Input: torch.Size([1, 1, 64])  (just 1 new token!)
  Output: torch.Size([1, 1, 64])
  K cache: torch.Size([1, 4, 6, 16])  (seq grew to 6)

Step 2:
  Input: torch.Size([1, 1, 64])  (just 1 new token!)
  Output: torch.Size([1, 1, 64])
  K cache: torch.Size([1, 4, 7, 16])  (seq grew to 7)

Step 3:
  Input: torch.Size([1, 1, 64])  (just 1 new token!)
  Output: torch.Size([1, 1, 64])
  K cache: torch.Size([1, 4, 8, 16])  (seq grew to 8)

============================================================
After 3 decode steps: cache has 8 positions
Started with 5 (prompt) + generated 3 = 8 total
============================================================

Complexity Analysis

Let’s quantify the improvement. Consider generating $n$ tokens after a prompt of length $p$:

Without KV-Cache

| Step | Tokens Processed | K, V Computed |
|------|------------------|---------------|
| 1 | p + 1 | p + 1 |
| 2 | p + 2 | p + 2 |
| ... | ... | ... |
| n | p + n | p + n |

Total K, V computations:

$$\sum_{i=1}^{n} (p + i) = np + \frac{n(n+1)}{2} = O(n^2 + np)$$

With KV-Cache

| Step | Tokens Processed | K, V Computed |
|---------|---|---|
| Prefill | p | p |
| 1 | 1 | 1 |
| 2 | 1 | 1 |
| ... | ... | ... |
| n | 1 | 1 |

Total K, V computations:

$$p + n = O(n + p)$$

Speedup factor: $\frac{O(n^2)}{O(n)} = O(n)$

For generating 100 tokens, this is ~50× fewer K, V computations!
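
To make these totals concrete, here is a short tally for a 10-token prompt and 100 generated tokens (an illustrative check, separate from the benchmark below):

p, n = 10, 100  # prompt length, tokens to generate

without_cache = sum(p + i for i in range(1, n + 1))  # reprocess the whole sequence each step
with_cache = p + n                                   # prefill once, then one token per step

print(without_cache)               # 6050
print(with_cache)                  # 110
print(without_cache / with_cache)  # 55.0 -- roughly 50x fewer K, V computations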

Memory Cost

KV-caching trades memory for speed. For each layer, we store:

$$\text{Cache size} = 2 \times \text{batch} \times \text{num\_heads} \times \text{seq\_len} \times d_k \times 4 \text{ bytes}$$

For our example model (6 layers, 4 heads, d_model=256, seq_len=512):

$$6 \times 2 \times 1 \times 4 \times 512 \times 64 \times 4 \text{ bytes} \approx 6.3 \text{ MB}$$

This is tiny compared to model weights (~100MB for a small model). The tradeoff is overwhelmingly worth it!

| Component | Memory |
|-----------|--------|
| Model weights | ~100 MB |
| KV cache (512 tokens) | ~6 MB |
| Speedup | ~50× |
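
A small helper to reproduce that estimate (a sketch assuming float32 storage and the example configuration above):

def kv_cache_bytes(num_layers, batch, num_heads, seq_len, d_k, bytes_per_elem=4):
    """Total KV-cache size across all layers: one K and one V tensor per layer."""
    return num_layers * 2 * batch * num_heads * seq_len * d_k * bytes_per_elem

size = kv_cache_bytes(num_layers=6, batch=1, num_heads=4, seq_len=512, d_k=64)
print(f"{size / 1e6:.1f} MB")  # 6.3 MB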

Benchmark

Let’s measure the actual speedup:

# Attention WITHOUT caching (for comparison)
class NoCacheAttention(nn.Module):
    """Standard attention that recomputes everything each time."""
    
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
    
    def forward(self, x):
        batch_size, seq_len, _ = x.shape
        
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn = F.softmax(scores, dim=-1)
        context = torch.matmul(attn, V)
        
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.W_o(context)
# Setup benchmark
d_model = 256
num_heads = 8
prompt_len = 10
generate_tokens = 50

attn_no_cache = NoCacheAttention(d_model, num_heads)
attn_with_cache = CachedMultiHeadAttention(d_model, num_heads)

# Share weights for fair comparison
attn_with_cache.W_q = attn_no_cache.W_q
attn_with_cache.W_k = attn_no_cache.W_k
attn_with_cache.W_v = attn_no_cache.W_v
attn_with_cache.W_o = attn_no_cache.W_o

print(f"Benchmark: Generate {generate_tokens} tokens after {prompt_len}-token prompt")
print(f"Model: d_model={d_model}, num_heads={num_heads}")
print("=" * 55)
Benchmark: Generate 50 tokens after 10-token prompt
Model: d_model=256, num_heads=8
=======================================================
# WITHOUT cache: must reprocess entire sequence each step
prompt = torch.randn(1, prompt_len, d_model)
sequence = prompt.clone()

start = time.time()
for _ in range(generate_tokens):
    # Process ENTIRE sequence each time
    output = attn_no_cache(sequence)
    # Take last position's output as "next token" (simplified)
    new_token = output[:, -1:, :]
    sequence = torch.cat([sequence, new_token], dim=1)
time_no_cache = time.time() - start

print(f"\nWITHOUT KV-Cache:")
print(f"  Final sequence length: {sequence.shape[1]}")
print(f"  Time: {time_no_cache*1000:.1f} ms")

WITHOUT KV-Cache:
  Final sequence length: 60
  Time: 11.6 ms
# WITH cache: process only new tokens
prompt = torch.randn(1, prompt_len, d_model)

start = time.time()

# Prefill: process entire prompt
output, kv_cache = attn_with_cache(prompt, kv_cache=None)
last_output = output[:, -1:, :]

# Decode: process one token at a time
for _ in range(generate_tokens):
    output, kv_cache = attn_with_cache(last_output, kv_cache=kv_cache)
    last_output = output  # Already shape (1, 1, d_model)

time_with_cache = time.time() - start

print(f"\nWITH KV-Cache:")
print(f"  Final cache length: {kv_cache[0].shape[2]}")
print(f"  Time: {time_with_cache*1000:.1f} ms")

print(f"\n{'='*55}")
print(f"Speedup: {time_no_cache/time_with_cache:.1f}x faster!")
print(f"{'='*55}")

WITH KV-Cache:
  Final cache length: 60
  Time: 6.2 ms

=======================================================
Speedup: 1.9x faster!
=======================================================

Important Implementation Details

1. Position Encodings

When using learned or sinusoidal position encodings, the new token must receive the correct position index. If you’ve generated tokens 0-9 and are now generating token 10, it must get position embedding 10, not 0.

# Track position during generation
current_pos = prompt_length
for _ in range(num_tokens):
    pos_embed = get_position_embedding(current_pos)
    ...
    current_pos += 1
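
For example, with sinusoidal encodings the decode step can simply index the table at the token’s absolute position (a minimal sketch; `sinusoidal_table` is a hypothetical helper, not part of the model above):

import math
import torch

def sinusoidal_table(max_len, d_model):
    """Standard sinusoidal position encodings, shape (max_len, d_model)."""
    pos = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
    div = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model))
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(pos * div)
    pe[:, 1::2] = torch.cos(pos * div)
    return pe

pe = sinusoidal_table(max_len=512, d_model=64)
prompt_len = 10
new_token_embed = torch.randn(1, 1, 64)                             # embedding of the new token
new_token_embed = new_token_embed + pe[prompt_len].view(1, 1, -1)   # position index = prompt_len, not 0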

2. Causal Masking

During prefill, you still need causal masking so position $i$ only attends to positions $\le i$. During decode, no mask is needed: each new token naturally attends only to past tokens (the cache) and itself.
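
A minimal standalone sketch of the two cases (independent of the attention classes above, which omit masking for simplicity):

import torch

seq = 5  # prompt length during prefill

# Prefill: queries cover all prompt positions, so future positions must be masked out.
prefill_mask = torch.triu(torch.ones(seq, seq, dtype=torch.bool), diagonal=1)
scores = torch.randn(seq, seq)                            # stand-in attention scores
scores = scores.masked_fill(prefill_mask, float("-inf"))  # block attention to future tokens

# Decode: the single query's score row covers only cached (past) positions plus itself,
# so every entry is already legal and no mask is needed.
decode_scores = torch.randn(1, seq + 1)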

3. Multi-Layer Caching

A full transformer has multiple layers, each with its own KV cache:

# Cache is a list of (K, V) tuples, one per layer
kv_caches = [None] * num_layers

for i, layer in enumerate(layers):
    x, kv_caches[i] = layer(x, kv_cache=kv_caches[i])

Why Every Production System Uses This

KV-caching isn’t optional for practical LLM deployment. Consider:

| Scenario | Without Cache | With Cache |
|----------|---------------|------------|
| Generate 100 tokens | 5,050 computations | 100 computations |
| Generate 1000 tokens | 500,500 computations | 1,000 computations |
| Generate 4096 tokens | 8,390,656 computations | 4,096 computations |

At 4096 tokens (common context length), caching provides a 2048× reduction in K, V computations. The memory cost (~100MB for a large model) is negligible compared to this speedup.

GPT, Claude, Gemini—every production system uses KV-caching. It’s not an optimization; it’s a requirement for usable latency.

Summary

| Aspect | Without Cache | With Cache |
|--------|---------------|------------|
| K, V computation | All tokens every step | 1 token per step |
| Time complexity | $O(n^2)$ | $O(n)$ |
| Memory | Minimal | Cache grows with context |
| Practical speedup | Baseline | 10-1000× faster |

Key insight: K and V for past tokens are deterministic—compute once, reuse forever.

The tradeoff: Memory for speed. Almost always worth it.

Next: Interpretability

We’ve built, trained, and optimized a transformer. But what has it actually learned? The final notebook explores mechanistic interpretability—techniques for peeking inside the black box to understand what patterns and circuits the model has discovered.