Why Multiple Heads?¶
We just computed attention separately for two heads. Each head has its own Q, K, V projections, so each head learns to look at the data differently.
Think of it like having multiple experts examine the same text:
One head might specialize in syntactic relationships (subject-verb, noun-adjective)
Another might focus on semantic relationships (related concepts, coreference)
Another might track positional patterns (nearby words, document structure)
In practice, what heads learn is emergent—we don’t design them for specific tasks. But research has shown that different heads genuinely specialize. Some learn to track quotation marks. Others focus on rare words. Others capture long-range dependencies.
The multi-head mechanism lets the model attend to information from different representation subspaces simultaneously. It’s one of the key innovations that made transformers so effective.
The Multi-Head Process¶
We have attention outputs from 2 heads, each with shape [5, 8]. To get back to d_model = 16:
1. Concatenate the head outputs: [5, 8] and [5, 8] → [5, 16]
2. Project through the output matrix W_O: [5, 16] @ [16, 16] → [5, 16]
The concatenation is straightforward—just stack the vectors side by side.
The output projection through W_O is important: it lets the model mix information across heads. Without it, the heads would remain completely independent. With it, the model can learn combinations like “if head 0 sees X and head 1 sees Y, output Z.”
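Putting the two steps together, the whole “combine the heads” operation is a concatenation followed by a single matrix multiply. Here is a compact sketch (an addition, not from the original notebooks) that assumes the matmul helper defined in the code cell below and the concat @ W_O^T convention used later in this notebook; the cells that follow perform the same two steps explicitly, one at a time.

def combine_heads(head_outputs, W_O_T):
    # head_outputs: list of per-head attention outputs, each [seq_len, d_k]
    # W_O_T: transposed output projection, shape [d_model, d_model]
    seq_len = len(head_outputs[0])
    # Step 1: concatenate each position's per-head vectors -> [seq_len, d_model]
    concat = [sum((head[i] for head in head_outputs), []) for i in range(seq_len)]
    # Step 2: project through the output matrix -> [seq_len, d_model]
    return matmul(concat, W_O_T)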
import random
import math
random.seed(42)
VOCAB_SIZE = 6
D_MODEL = 16
MAX_SEQ_LEN = 5
NUM_HEADS = 2
D_K = D_MODEL // NUM_HEADS # 8
TOKEN_NAMES = ["<PAD>", "<BOS>", "<EOS>", "I", "like", "transformers"]

# Helper functions
def random_vector(size, scale=0.1):
return [random.gauss(0, scale) for _ in range(size)]
def random_matrix(rows, cols, scale=0.1):
return [[random.gauss(0, scale) for _ in range(cols)] for _ in range(rows)]
def add_vectors(v1, v2):
return [a + b for a, b in zip(v1, v2)]
def matmul(A, B):
m, n, p = len(A), len(A[0]), len(B[0])
return [[sum(A[i][k] * B[k][j] for k in range(n)) for j in range(p)] for i in range(m)]
def transpose(A):
rows, cols = len(A), len(A[0])
return [[A[i][j] for i in range(rows)] for j in range(cols)]
def softmax(vec):
finite_vals = [v for v in vec if v != float('-inf')]
max_val = max(finite_vals) if finite_vals else 0
exp_vec = [math.exp(v - max_val) if v != float('-inf') else 0.0 for v in vec]
total = sum(exp_vec)
return [e / total for e in exp_vec]
def format_vector(vec, decimals=4):
return "[" + ", ".join([f"{v:7.{decimals}f}" for v in vec]) + "]"# Recreate everything from previous notebooks
E_token = [random_vector(D_MODEL) for _ in range(VOCAB_SIZE)]
E_pos = [random_vector(D_MODEL) for _ in range(MAX_SEQ_LEN)]
tokens = [1, 3, 4, 5, 2]
seq_len = len(tokens)
X = [add_vectors(E_token[tokens[i]], E_pos[i]) for i in range(seq_len)]
W_Q = [random_matrix(D_MODEL, D_K) for _ in range(NUM_HEADS)]
W_K = [random_matrix(D_MODEL, D_K) for _ in range(NUM_HEADS)]
W_V = [random_matrix(D_MODEL, D_K) for _ in range(NUM_HEADS)]
Q_all = [matmul(X, W_Q[h]) for h in range(NUM_HEADS)]
K_all = [matmul(X, W_K[h]) for h in range(NUM_HEADS)]
V_all = [matmul(X, W_V[h]) for h in range(NUM_HEADS)]
def attention(Q, K, V):
seq_len, d_k = len(Q), len(Q[0])
scale = math.sqrt(d_k)
scores = matmul(Q, transpose(K))
scaled = [[s / scale for s in row] for row in scores]
for i in range(seq_len):
for j in range(seq_len):
if j > i:
scaled[i][j] = float('-inf')
weights = [softmax(row) for row in scaled]
return matmul(weights, V), weights
attention_output_all = []
for h in range(NUM_HEADS):
output, _ = attention(Q_all[h], K_all[h], V_all[h])
attention_output_all.append(output)
print(f"Recreated attention outputs from previous notebooks")
print(f"Head 0 output shape: [{len(attention_output_all[0])}, {len(attention_output_all[0][0])}]")
print(f"Head 1 output shape: [{len(attention_output_all[1])}, {len(attention_output_all[1][0])}]")Recreated attention outputs from previous notebooks
Head 0 output shape: [5, 8]
Head 1 output shape: [5, 8]
Step 1: Concatenate Head Outputs¶
Each head produced an output of shape [5, 8]. We concatenate them along the last dimension:
For each position i, we take the 8-dimensional vector from head 0 and the 8-dimensional vector from head 1, and stick them together to get a 16-dimensional vector.
Result shape: [5, 16]
# Concatenate head outputs
concat_output = []
for i in range(seq_len):
# Concatenate head 0 and head 1 outputs for position i
concat_row = attention_output_all[0][i] + attention_output_all[1][i]
concat_output.append(concat_row)
print("Concatenated Attention Outputs")
print(f"Shape: [{seq_len}, {D_MODEL}]")
print()
for i, row in enumerate(concat_output):
print(f" pos {i} ({TOKEN_NAMES[tokens[i]]:12s}): {format_vector(row)}")Concatenated Attention Outputs
Shape: [5, 16]
pos 0 (<BOS> ): [ 0.0800, 0.0257, -0.0117, -0.1056, 0.0339, -0.0891, -0.0083, -0.0737, 0.0107, -0.0291, -0.0100, -0.0312, 0.0214, 0.0372, 0.0105, 0.0279]
pos 1 (I ): [ 0.0683, 0.0368, -0.0263, -0.0574, 0.0152, -0.0174, -0.0084, -0.0760, -0.0199, -0.0151, 0.0026, 0.0107, 0.0091, -0.0204, -0.0320, -0.0193]
pos 2 (like ): [ 0.0247, 0.0789, 0.0074, -0.0635, 0.0180, -0.0098, -0.0184, -0.0173, -0.0320, -0.0102, 0.0178, -0.0153, 0.0433, 0.0026, 0.0002, -0.0198]
pos 3 (transformers): [ 0.0254, 0.0511, -0.0182, -0.0322, 0.0103, -0.0126, -0.0282, 0.0018, -0.0111, -0.0085, 0.0093, 0.0101, 0.0440, 0.0237, 0.0056, -0.0311]
pos 4 (<EOS> ): [ 0.0325, 0.0367, -0.0202, -0.0262, 0.0188, -0.0040, -0.0321, 0.0167, -0.0119, -0.0013, -0.0069, 0.0016, 0.0480, 0.0233, 0.0096, -0.0121]
# Show the concatenation for position 0 in detail
print("Detailed: Concatenation for position 0 (<BOS>)")
print("=" * 70)
print()
print(f"Head 0 output (8 dims): {format_vector(attention_output_all[0][0])}")
print()
print(f"Head 1 output (8 dims): {format_vector(attention_output_all[1][0])}")
print()
print(f"Concatenated (16 dims): {format_vector(concat_output[0])}")
print()
print("First 8 elements come from head 0, last 8 from head 1.")Detailed: Concatenation for position 0 (<BOS>)
======================================================================
Head 0 output (8 dims): [ 0.0800, 0.0257, -0.0117, -0.1056, 0.0339, -0.0891, -0.0083, -0.0737]
Head 1 output (8 dims): [ 0.0107, -0.0291, -0.0100, -0.0312, 0.0214, 0.0372, 0.0105, 0.0279]
Concatenated (16 dims): [ 0.0800, 0.0257, -0.0117, -0.1056, 0.0339, -0.0891, -0.0083, -0.0737, 0.0107, -0.0291, -0.0100, -0.0312, 0.0214, 0.0372, 0.0105, 0.0279]
First 8 elements come from head 0, last 8 from head 1.
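For comparison: in an array library this entire step is one call. A minimal sketch assuming NumPy is available (the notebooks themselves use only pure Python), reusing the head outputs computed above:

import numpy as np

head0 = np.array(attention_output_all[0])             # shape (5, 8)
head1 = np.array(attention_output_all[1])             # shape (5, 8)
concat_np = np.concatenate([head0, head1], axis=-1)   # shape (5, 16)
print(concat_np.shape)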
Step 2: Output Projection¶
The concatenated output is [5, 16]. We project it through the output weight matrix W_O:

output = concat @ W_O^T

where W_O has shape [d_model, d_model] = [16, 16]. (W_O is square and learned, so multiplying by W_O or by its transpose is equivalent in capacity; the code below uses the transpose.)
Why this projection?
Without W_O, the heads would be completely separate. The first 8 dimensions would always come from head 0, the last 8 from head 1. There’s no interaction.
With W_O, each output dimension becomes a learned combination of all dimensions from all heads. The model can learn:
“If head 0 found a verb and head 1 found the subject, strengthen the connection”
“Suppress noise when heads disagree”
“Combine syntactic and semantic signals into a unified representation”
This mixing is crucial for multi-head attention’s power.
# Initialize output projection matrix
W_O = random_matrix(D_MODEL, D_MODEL) # [16, 16]
print(f"Output Projection Matrix W_O")
print(f"Shape: [{D_MODEL}, {D_MODEL}]")
print()
print("(16×16 = 256 learnable parameters for mixing heads)")Output Projection Matrix W_O
Shape: [16, 16]
(16×16 = 256 learnable parameters for mixing heads)
# Apply output projection: concat @ W_O^T
W_O_T = transpose(W_O)
multi_head_output = matmul(concat_output, W_O_T)
print("Multi-Head Attention Output")
print(f"Shape: [{seq_len}, {D_MODEL}] @ [{D_MODEL}, {D_MODEL}] = [{seq_len}, {D_MODEL}]")
print()
for i, row in enumerate(multi_head_output):
print(f" pos {i} ({TOKEN_NAMES[tokens[i]]:12s}): {format_vector(row)}")Multi-Head Attention Output
Shape: [5, 16] @ [16, 16] = [5, 16]
pos 0 (<BOS> ): [ 0.0334, 0.0033, -0.0041, -0.0073, 0.0185, 0.0074, 0.0169, 0.0107, 0.0277, 0.0060, 0.0222, 0.0241, 0.0074, 0.0067, -0.0067, 0.0063]
pos 1 (I ): [ 0.0269, 0.0066, 0.0113, -0.0154, 0.0114, 0.0032, -0.0065, -0.0108, 0.0190, -0.0091, 0.0180, 0.0097, -0.0075, 0.0061, -0.0079, 0.0110]
pos 2 (like ): [ 0.0085, 0.0086, 0.0159, -0.0177, 0.0026, 0.0205, -0.0057, -0.0055, 0.0059, -0.0043, 0.0007, -0.0053, 0.0075, -0.0012, -0.0043, -0.0016]
pos 3 (transformers): [ 0.0064, -0.0084, 0.0092, -0.0173, 0.0068, 0.0119, -0.0100, -0.0027, 0.0027, -0.0073, 0.0036, -0.0076, 0.0022, -0.0070, -0.0095, -0.0070]
pos 4 (<EOS> ): [ 0.0039, -0.0068, 0.0098, -0.0136, 0.0031, 0.0090, -0.0086, -0.0027, -0.0003, -0.0044, -0.0029, -0.0062, 0.0060, -0.0048, -0.0036, -0.0115]
# Detailed calculation for one element
print("Detailed: Computing output[0][0]")
print("=" * 70)
print()
print("output[0][0] = concat[0] · W_O[:, 0]")
print()
print(f"concat[0] (16 dims): {format_vector(concat_output[0])}")
print()
# The projection above used W_O^T, so output[0][0] pairs concat[0] with
# column 0 of W_O^T, which is row 0 of W_O
w_slice = [W_O_T[i][0] for i in range(D_MODEL)]
result = sum(concat_output[0][j] * w_slice[j] for j in range(D_MODEL))
print(f"Dot product = {result:.6f}")
print(f"Actual output[0][0] = {multi_head_output[0][0]:.6f}")Detailed: Computing output[0][0]
======================================================================
output[0][0] = concat[0] · W_O^T[:, 0]  (column 0 of W_O^T = row 0 of W_O)
concat[0] (16 dims): [ 0.0800, 0.0257, -0.0117, -0.1056, 0.0339, -0.0891, -0.0083, -0.0737, 0.0107, -0.0291, -0.0100, -0.0312, 0.0214, 0.0372, 0.0105, 0.0279]
Dot product = 0.033421
Actual output[0][0] = 0.033421
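To make the mixing role of W_O concrete: concatenating and then projecting is mathematically identical to giving each head its own [8, 16] slice of the projection and summing the results. The following check is an addition to the original cells and reuses the variables defined above:

# Rows 0-7 of W_O^T act on head 0's dimensions, rows 8-15 on head 1's
W_OT_head0 = W_O_T[:D_K]
W_OT_head1 = W_O_T[D_K:]

per_head_0 = matmul(attention_output_all[0], W_OT_head0)  # [5, 16]
per_head_1 = matmul(attention_output_all[1], W_OT_head1)  # [5, 16]
summed = [[a + b for a, b in zip(r0, r1)] for r0, r1 in zip(per_head_0, per_head_1)]

max_diff = max(abs(summed[i][j] - multi_head_output[i][j])
               for i in range(seq_len) for j in range(D_MODEL))
print(f"Max |difference| vs. concat-then-project: {max_diff:.2e}")

The difference should be zero up to floating-point rounding. Every output dimension receives a contribution from both heads, which is exactly the cross-head mixing described in Step 2.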
Comparing Input and Output¶
Let’s compare what went into multi-head attention (the original embeddings X) with what came out.
The output has the same shape as the input [5, 16], but each vector now incorporates information from other positions via attention. The representation for “like” now contains information about “I” and “<BOS>”—it’s no longer just about the word “like” in isolation.
print("Comparison: Input vs Multi-Head Attention Output")
print("=" * 80)
print()
for i in range(seq_len):
print(f"Position {i} ({TOKEN_NAMES[tokens[i]]})")
print(f" Input X: {format_vector(X[i])}")
print(f" Output: {format_vector(multi_head_output[i])}")
# Compute magnitude change
mag_in = math.sqrt(sum(x**2 for x in X[i]))
mag_out = math.sqrt(sum(x**2 for x in multi_head_output[i]))
print(f" Magnitude: {mag_in:.4f} → {mag_out:.4f}")
    print()

Comparison: Input vs Multi-Head Attention Output
================================================================================
Position 0 (<BOS>)
Input X: [ 0.1473, 0.1281, 0.1995, -0.0465, 0.2125, -0.1338, -0.0829, -0.0638, 0.0722, 0.1183, 0.1193, 0.0937, -0.1594, -0.0402, 0.1124, -0.2064]
Output: [ 0.0334, 0.0033, -0.0041, -0.0073, 0.0185, 0.0074, 0.0169, 0.0107, 0.0277, 0.0060, 0.0222, 0.0241, 0.0074, 0.0067, -0.0067, 0.0063]
Magnitude: 0.5278 → 0.0637
Position 1 (I)
Input X: [-0.1254, -0.0720, 0.1255, -0.0556, -0.0678, 0.3698, -0.1265, -0.1463, 0.0866, 0.0181, 0.0726, -0.0374, 0.2312, -0.0091, 0.0860, -0.0251]
Output: [ 0.0269, 0.0066, 0.0113, -0.0154, 0.0114, 0.0032, -0.0065, -0.0108, 0.0190, -0.0091, 0.0180, 0.0097, -0.0075, 0.0061, -0.0079, 0.0110]
Magnitude: 0.5427 → 0.0507
Position 2 (like)
Input X: [ 0.2319, -0.2747, -0.0089, 0.0576, 0.1430, -0.0957, 0.1571, 0.2913, 0.2154, 0.0103, -0.0510, -0.1353, -0.0296, -0.0371, -0.0262, 0.2770]
Output: [ 0.0085, 0.0086, 0.0159, -0.0177, 0.0026, 0.0205, -0.0057, -0.0055, 0.0059, -0.0043, 0.0007, -0.0053, 0.0075, -0.0012, -0.0043, -0.0016]
Magnitude: 0.6472 → 0.0369
Position 3 (transformers)
Input X: [-0.1334, 0.0027, -0.3410, -0.1478, -0.0307, 0.1240, 0.2642, -0.0063, -0.0856, 0.0626, 0.1602, 0.1385, -0.0427, 0.0122, 0.0991, 0.1081]
Output: [ 0.0064, -0.0084, 0.0092, -0.0173, 0.0068, 0.0119, -0.0100, -0.0027, 0.0027, -0.0073, 0.0036, -0.0076, 0.0022, -0.0070, -0.0095, -0.0070]
Magnitude: 0.5671 → 0.0334
Position 4 (<EOS>)
Input X: [-0.1346, 0.0002, -0.0629, 0.3029, 0.0908, -0.1515, 0.0959, 0.0481, 0.0032, 0.0225, 0.1310, 0.0306, -0.1088, 0.0649, 0.0880, -0.0130]
Output: [ 0.0039, -0.0068, 0.0098, -0.0136, 0.0031, 0.0090, -0.0086, -0.0027, -0.0003, -0.0044, -0.0029, -0.0062, 0.0060, -0.0048, -0.0036, -0.0115]
Magnitude: 0.4462 → 0.0280
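As a further comparison (an aside, not in the original notebook), cosine similarity shows how much each vector's direction changed, independent of the much smaller output magnitudes:

def cosine_similarity(v1, v2):
    # dot(v1, v2) / (|v1| * |v2|)
    dot = sum(a * b for a, b in zip(v1, v2))
    norm1 = math.sqrt(sum(a * a for a in v1))
    norm2 = math.sqrt(sum(b * b for b in v2))
    return dot / (norm1 * norm2)

for i in range(seq_len):
    sim = cosine_similarity(X[i], multi_head_output[i])
    print(f" pos {i} ({TOKEN_NAMES[tokens[i]]:12s}): cosine(X, output) = {sim:7.4f}")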
Parameter Count¶
Multi-head attention has a lot of parameters:
| Component | Shape | Count |
|---|---|---|
| W_Q (per head) | [16, 8] × 2 heads | 256 |
| W_K (per head) | [16, 8] × 2 heads | 256 |
| W_V (per head) | [16, 8] × 2 heads | 256 |
| W_O | [16, 16] | 256 |
| Total | | 1,024 |
That’s 1,024 parameters just for attention in one transformer block. In GPT-3, with d_model = 12,288 and 96 heads, each of W_Q, W_K, W_V, and W_O is effectively 12,288 × 12,288, so attention alone accounts for roughly 600 million parameters per layer.
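These counts can be checked directly from the configuration; the GPT-3 figure is just 4 × d_model² (ignoring bias terms). A quick sanity check:

# This notebook's attention parameters
qkv_params = 3 * NUM_HEADS * D_MODEL * D_K     # W_Q, W_K, W_V: 3 * 2 * 16 * 8 = 768
w_o_params = D_MODEL * D_MODEL                 # W_O: 16 * 16 = 256
print(f"Attention parameters (this notebook): {qkv_params + w_o_params}")   # 1024

# GPT-3 scale: the per-head Q/K/V projections stack into 12288 x 12288 matrices,
# plus a 12288 x 12288 output projection
gpt3_d_model = 12288
print(f"Attention parameters per GPT-3 layer: {4 * gpt3_d_model ** 2:,}")    # about 604 million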
What Multi-Head Attention Accomplished¶
We started with embeddings where each position was independent. After multi-head attention:
Each position’s representation incorporates information from previous positions
Two different “perspectives” (heads) contributed to this mixing
The output projection combined these perspectives into a unified representation
The model hasn’t changed the number of positions—we still have 5 vectors of dimension 16. But the content of those vectors is now context-dependent. This is the fundamental innovation of attention.
What’s Next¶
Multi-head attention is done. But we’re not finished with the transformer block.
Next comes the feed-forward network (FFN)—a simple two-layer neural network applied to each position independently. Where attention lets positions communicate, the FFN lets each position process the information it’s gathered.
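As a preview only, here is a minimal sketch of a position-wise FFN. The ReLU activation, the d_ff = 4 × d_model width, and the omitted bias terms are assumptions for illustration; the next notebook defines its own weights and details.

D_FF = 4 * D_MODEL  # 64; a common choice, assumed here

W_1 = random_matrix(D_MODEL, D_FF)   # [16, 64]
W_2 = random_matrix(D_FF, D_MODEL)   # [64, 16]

def feed_forward(x_vec):
    # Linear -> ReLU -> Linear, applied to a single position's vector (biases omitted)
    hidden = [max(0.0, sum(x_vec[i] * W_1[i][j] for i in range(D_MODEL))) for j in range(D_FF)]
    return [sum(hidden[j] * W_2[j][k] for j in range(D_FF)) for k in range(D_MODEL)]

# Applied independently to each position; no communication between positions
ffn_preview = [feed_forward(row) for row in multi_head_output]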
# Store for next notebook
multi_head_data = {
'X': X,
'tokens': tokens,
'multi_head_output': multi_head_output,
'W_O': W_O,
'concat_output': concat_output
}
print("Multi-head attention complete. Ready for feed-forward network.")Multi-head attention complete. Ready for feed-forward network.