Why Multiple Heads?¶
We just computed attention separately for two heads. Each head has its own Q, K, V projections, so each head learns to look at the data differently.
Think of it like having multiple experts examine the same text:
One head might specialize in syntactic relationships (subject-verb, noun-adjective)
Another might focus on semantic relationships (related concepts, coreference)
Another might track positional patterns (nearby words, document structure)
In practice, what heads learn is emergent; we don’t design them for specific tasks. But research has shown that different heads genuinely specialize. Some learn to track quotation marks. Others focus on rare words. Others capture long-range dependencies.
The multi-head mechanism lets the model attend to information from different representation subspaces simultaneously. It’s one of the key innovations that made transformers so effective.
The Multi-Head Process¶
We have attention outputs from 2 heads, each with shape [5, 8]. To get back to d_model = 16:
Concatenate the head outputs: [5, 8] and [5, 8] → [5, 16]
Project through the output matrix W_O: [5, 16] @ [16, 16] → [5, 16]
The concatenation is straightforward. Just stack the vectors side by side.
The output projection through W_O is important: it lets the model mix information across heads. Without it, the heads would remain completely independent. With it, the model can learn combinations like “if head 0 sees X and head 1 sees Y, output Z.”
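As a preview, the whole combine step fits in a few lines. The sketch below is illustrative only; the helper name combine_heads and its arguments are not part of the notebook's running code, and the cells that follow perform the same two operations explicitly, step by step.

# Minimal sketch of the combine step (illustrative, not used later).
# head_outputs: list of [seq_len, d_k] matrices (one per head); W_O: [d_model, d_model].
def combine_heads(head_outputs, W_O):
    seq_len = len(head_outputs[0])
    # Step 1: concatenate along the feature dimension -> [seq_len, d_model]
    concat = [sum((head[i] for head in head_outputs), []) for i in range(seq_len)]
    # Step 2: the output projection mixes information across heads -> [seq_len, d_model]
    return [[sum(c * w for c, w in zip(row, col)) for col in zip(*W_O)] for row in concat]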
import random
import math

random.seed(42)

VOCAB_SIZE = 6
D_MODEL = 16
MAX_SEQ_LEN = 5
NUM_HEADS = 2
D_K = D_MODEL // NUM_HEADS  # 8
TOKEN_NAMES = ["<PAD>", "<BOS>", "<EOS>", "I", "like", "transformers"]

# Helper functions
def random_vector(size, scale=0.1):
    return [random.gauss(0, scale) for _ in range(size)]

def random_matrix(rows, cols, scale=0.1):
    return [[random.gauss(0, scale) for _ in range(cols)] for _ in range(rows)]

def add_vectors(v1, v2):
    return [a + b for a, b in zip(v1, v2)]

def matmul(A, B):
    m, n, p = len(A), len(A[0]), len(B[0])
    return [[sum(A[i][k] * B[k][j] for k in range(n)) for j in range(p)] for i in range(m)]

def transpose(A):
    rows, cols = len(A), len(A[0])
    return [[A[i][j] for i in range(rows)] for j in range(cols)]

def softmax(vec):
    finite_vals = [v for v in vec if v != float('-inf')]
    max_val = max(finite_vals) if finite_vals else 0
    exp_vec = [math.exp(v - max_val) if v != float('-inf') else 0.0 for v in vec]
    total = sum(exp_vec)
    return [e / total for e in exp_vec]

def format_vector(vec, decimals=4):
    return "[" + ", ".join([f"{v:7.{decimals}f}" for v in vec]) + "]"

# Recreate everything from previous notebooks
E_token = [random_vector(D_MODEL) for _ in range(VOCAB_SIZE)]
E_pos = [random_vector(D_MODEL) for _ in range(MAX_SEQ_LEN)]

tokens = [1, 3, 4, 5, 2]
seq_len = len(tokens)
X = [add_vectors(E_token[tokens[i]], E_pos[i]) for i in range(seq_len)]

W_Q = [random_matrix(D_MODEL, D_K) for _ in range(NUM_HEADS)]
W_K = [random_matrix(D_MODEL, D_K) for _ in range(NUM_HEADS)]
W_V = [random_matrix(D_MODEL, D_K) for _ in range(NUM_HEADS)]

Q_all = [matmul(X, W_Q[h]) for h in range(NUM_HEADS)]
K_all = [matmul(X, W_K[h]) for h in range(NUM_HEADS)]
V_all = [matmul(X, W_V[h]) for h in range(NUM_HEADS)]

def attention(Q, K, V):
    seq_len, d_k = len(Q), len(Q[0])
    scale = math.sqrt(d_k)
    scores = matmul(Q, transpose(K))
    scaled = [[s / scale for s in row] for row in scores]
    for i in range(seq_len):
        for j in range(seq_len):
            if j > i:
                scaled[i][j] = float('-inf')
    weights = [softmax(row) for row in scaled]
    return matmul(weights, V), weights

attention_output_all = []
for h in range(NUM_HEADS):
    output, _ = attention(Q_all[h], K_all[h], V_all[h])
    attention_output_all.append(output)

print(f"Recreated attention outputs from previous notebooks")
print(f"Head 0 output shape: [{len(attention_output_all[0])}, {len(attention_output_all[0][0])}]")
print(f"Head 1 output shape: [{len(attention_output_all[1])}, {len(attention_output_all[1][0])}]")

Recreated attention outputs from previous notebooks
Head 0 output shape: [5, 8]
Head 1 output shape: [5, 8]
Step 1: Concatenate Head Outputs¶
Each head produced an output of shape [5, 8]. We concatenate them along the last dimension:
For each position i, we take the 8-dimensional vector from head 0 and the 8-dimensional vector from head 1, and stick them together to get a 16-dimensional vector.
Result shape: [5, 16]
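Because each row is a plain Python list, concatenation is just list addition with +. A tiny standalone example (the toy vectors here are made up and shorter than the real 8-dimensional head outputs):

# Python's list "+" concatenates, which is exactly what we need per position.
head0_vec = [0.1, 0.2, 0.3]   # stand-in for an 8-dim head-0 output
head1_vec = [0.4, 0.5, 0.6]   # stand-in for an 8-dim head-1 output
print(head0_vec + head1_vec)  # [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]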
# Concatenate head outputs
concat_output = []
for i in range(seq_len):
    # Concatenate head 0 and head 1 outputs for position i
    concat_row = attention_output_all[0][i] + attention_output_all[1][i]
    concat_output.append(concat_row)

print("Concatenated Attention Outputs")
print(f"Shape: [{seq_len}, {D_MODEL}]")
print()
for i, row in enumerate(concat_output):
    print(f" pos {i} ({TOKEN_NAMES[tokens[i]]:12s}): {format_vector(row)}")

Concatenated Attention Outputs
Shape: [5, 16]
pos 0 (<BOS> ): [ 0.0800, 0.0257, -0.0117, -0.1056, 0.0339, -0.0891, -0.0083, -0.0737, 0.0107, -0.0291, -0.0100, -0.0312, 0.0214, 0.0372, 0.0105, 0.0279]
pos 1 (I ): [ 0.0683, 0.0368, -0.0263, -0.0574, 0.0152, -0.0174, -0.0084, -0.0760, -0.0199, -0.0151, 0.0026, 0.0107, 0.0091, -0.0204, -0.0320, -0.0193]
pos 2 (like ): [ 0.0247, 0.0789, 0.0074, -0.0635, 0.0180, -0.0098, -0.0184, -0.0173, -0.0320, -0.0102, 0.0178, -0.0153, 0.0433, 0.0026, 0.0002, -0.0198]
pos 3 (transformers): [ 0.0254, 0.0511, -0.0182, -0.0322, 0.0103, -0.0126, -0.0282, 0.0018, -0.0111, -0.0085, 0.0093, 0.0101, 0.0440, 0.0237, 0.0056, -0.0311]
pos 4 (<EOS> ): [ 0.0325, 0.0367, -0.0202, -0.0262, 0.0188, -0.0040, -0.0321, 0.0167, -0.0119, -0.0013, -0.0069, 0.0016, 0.0480, 0.0233, 0.0096, -0.0121]
# Show the concatenation for position 0 in detail
print("Detailed: Concatenation for position 0 (<BOS>)")
print("=" * 70)
print()
print(f"Head 0 output (8 dims): {format_vector(attention_output_all[0][0])}")
print()
print(f"Head 1 output (8 dims): {format_vector(attention_output_all[1][0])}")
print()
print(f"Concatenated (16 dims): {format_vector(concat_output[0])}")
print()
print("First 8 elements come from head 0, last 8 from head 1.")Detailed: Concatenation for position 0 (<BOS>)
======================================================================
Head 0 output (8 dims): [ 0.0800, 0.0257, -0.0117, -0.1056, 0.0339, -0.0891, -0.0083, -0.0737]
Head 1 output (8 dims): [ 0.0107, -0.0291, -0.0100, -0.0312, 0.0214, 0.0372, 0.0105, 0.0279]
Concatenated (16 dims): [ 0.0800, 0.0257, -0.0117, -0.1056, 0.0339, -0.0891, -0.0083, -0.0737, 0.0107, -0.0291, -0.0100, -0.0312, 0.0214, 0.0372, 0.0105, 0.0279]
First 8 elements come from head 0, last 8 from head 1.
Step 2: Output Projection¶
The concatenated output is [5, 16]. We project it through the output weight matrix W_O:
[5, 16] @ [16, 16] → [5, 16]
where W_O has shape [d_model, d_model] = [16, 16].
Why this projection?
Without W_O, the heads would be completely separate: the first 8 dimensions would always come from head 0, the last 8 from head 1, with no interaction between them.
With W_O, each output dimension becomes a learned combination of all dimensions from all heads. The model can learn:
“If head 0 found a verb and head 1 found the subject, strengthen the connection”
“Suppress noise when heads disagree”
“Combine syntactic and semantic signals into a unified representation”
This mixing is crucial for multi-head attention’s power.
# Initialize output projection matrix
W_O = random_matrix(D_MODEL, D_MODEL) # [16, 16]
print(f"Output Projection Matrix W_O")
print(f"Shape: [{D_MODEL}, {D_MODEL}]")
print()
print("(16×16 = 256 learnable parameters for mixing heads)")Output Projection Matrix W_O
Shape: [16, 16]
(16×16 = 256 learnable parameters for mixing heads)
# Apply output projection: concat @ W_O^T
W_O_T = transpose(W_O)
multi_head_output = matmul(concat_output, W_O_T)
print("Multi-Head Attention Output")
print(f"Shape: [{seq_len}, {D_MODEL}] @ [{D_MODEL}, {D_MODEL}] = [{seq_len}, {D_MODEL}]")
print()
for i, row in enumerate(multi_head_output):
    print(f" pos {i} ({TOKEN_NAMES[tokens[i]]:12s}): {format_vector(row)}")

Multi-Head Attention Output
Shape: [5, 16] @ [16, 16] = [5, 16]
pos 0 (<BOS> ): [ 0.0334, 0.0033, -0.0041, -0.0073, 0.0185, 0.0074, 0.0169, 0.0107, 0.0277, 0.0060, 0.0222, 0.0241, 0.0074, 0.0067, -0.0067, 0.0063]
pos 1 (I ): [ 0.0269, 0.0066, 0.0113, -0.0154, 0.0114, 0.0032, -0.0065, -0.0108, 0.0190, -0.0091, 0.0180, 0.0097, -0.0075, 0.0061, -0.0079, 0.0110]
pos 2 (like ): [ 0.0085, 0.0086, 0.0159, -0.0177, 0.0026, 0.0205, -0.0057, -0.0055, 0.0059, -0.0043, 0.0007, -0.0053, 0.0075, -0.0012, -0.0043, -0.0016]
pos 3 (transformers): [ 0.0064, -0.0084, 0.0092, -0.0173, 0.0068, 0.0119, -0.0100, -0.0027, 0.0027, -0.0073, 0.0036, -0.0076, 0.0022, -0.0070, -0.0095, -0.0070]
pos 4 (<EOS> ): [ 0.0039, -0.0068, 0.0098, -0.0136, 0.0031, 0.0090, -0.0086, -0.0027, -0.0003, -0.0044, -0.0029, -0.0062, 0.0060, -0.0048, -0.0036, -0.0115]
# Detailed calculation for one element.
# The projection above is concat @ W_O^T, so output[0][0] is the dot product of
# concat[0] with row 0 of W_O (row 0 of W_O == column 0 of W_O^T).
print("Detailed: Computing output[0][0]")
print("=" * 70)
print()
print("output[0][0] = concat[0] · W_O[0, :]  (row 0 of W_O, since we multiply by W_O^T)")
print()
print(f"concat[0] (16 dims): {format_vector(concat_output[0])}")
print()
row_0 = W_O[0]
result = sum(concat_output[0][j] * row_0[j] for j in range(D_MODEL))
print(f"Dot product = {result:.6f}")
print(f"Actual output[0][0] = {multi_head_output[0][0]:.6f}")

Detailed: Computing output[0][0]
======================================================================
output[0][0] = concat[0] · W_O[0, :]  (row 0 of W_O, since we multiply by W_O^T)
concat[0] (16 dims): [ 0.0800, 0.0257, -0.0117, -0.1056, 0.0339, -0.0891, -0.0083, -0.0737, 0.0107, -0.0291, -0.0100, -0.0312, 0.0214, 0.0372, 0.0105, 0.0279]
Dot product = 0.033421
Actual output[0][0] = 0.033421
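A useful way to see the mixing that W_O performs: since the concatenated matrix is just [head 0 | head 1], projecting it through W_O^T is the same as projecting each head through its own block of rows of W_O^T and summing the results. The check below is not part of the original flow; it reuses the variables defined in the cells above and should agree with multi_head_output up to floating-point rounding.

# Equivalent view: concat @ W_O^T == head0 @ W_O^T[:8, :] + head1 @ W_O^T[8:, :]
head0_part = matmul(attention_output_all[0], W_O_T[:D_K])  # first 8 rows of W_O^T
head1_part = matmul(attention_output_all[1], W_O_T[D_K:])  # last 8 rows of W_O^T
recombined = [[a + b for a, b in zip(r0, r1)] for r0, r1 in zip(head0_part, head1_part)]

# Should match multi_head_output except for tiny floating-point differences
max_diff = max(abs(recombined[i][j] - multi_head_output[i][j])
               for i in range(seq_len) for j in range(D_MODEL))
print(f"Max difference vs. multi_head_output: {max_diff:.2e}")

Splitting W_O this way also makes explicit that each head contributes through its own 8 × 16 slice of the output projection.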
Comparing Input and Output¶
Let’s compare what went into multi-head attention (the original embeddings X) with what came out.
The output has the same shape as the input, [5, 16], but each vector now incorporates information from other positions via attention. The representation for “like” now contains information about “I” and “<BOS>”; it’s no longer just about the word “like” in isolation.
print("Comparison: Input vs Multi-Head Attention Output")
print("=" * 80)
print()
for i in range(seq_len):
    print(f"Position {i} ({TOKEN_NAMES[tokens[i]]})")
    print(f" Input X: {format_vector(X[i])}")
    print(f" Output: {format_vector(multi_head_output[i])}")
    # Compute magnitude change
    mag_in = math.sqrt(sum(x**2 for x in X[i]))
    mag_out = math.sqrt(sum(x**2 for x in multi_head_output[i]))
    print(f" Magnitude: {mag_in:.4f} → {mag_out:.4f}")
    print()

Comparison: Input vs Multi-Head Attention Output
================================================================================
Position 0 (<BOS>)
Input X: [ 0.1473, 0.1281, 0.1995, -0.0465, 0.2125, -0.1338, -0.0829, -0.0638, 0.0722, 0.1183, 0.1193, 0.0937, -0.1594, -0.0402, 0.1124, -0.2064]
Output: [ 0.0334, 0.0033, -0.0041, -0.0073, 0.0185, 0.0074, 0.0169, 0.0107, 0.0277, 0.0060, 0.0222, 0.0241, 0.0074, 0.0067, -0.0067, 0.0063]
Magnitude: 0.5278 → 0.0637
Position 1 (I)
Input X: [-0.1254, -0.0720, 0.1255, -0.0556, -0.0678, 0.3698, -0.1265, -0.1463, 0.0866, 0.0181, 0.0726, -0.0374, 0.2312, -0.0091, 0.0860, -0.0251]
Output: [ 0.0269, 0.0066, 0.0113, -0.0154, 0.0114, 0.0032, -0.0065, -0.0108, 0.0190, -0.0091, 0.0180, 0.0097, -0.0075, 0.0061, -0.0079, 0.0110]
Magnitude: 0.5427 → 0.0507
Position 2 (like)
Input X: [ 0.2319, -0.2747, -0.0089, 0.0576, 0.1430, -0.0957, 0.1571, 0.2913, 0.2154, 0.0103, -0.0510, -0.1353, -0.0296, -0.0371, -0.0262, 0.2770]
Output: [ 0.0085, 0.0086, 0.0159, -0.0177, 0.0026, 0.0205, -0.0057, -0.0055, 0.0059, -0.0043, 0.0007, -0.0053, 0.0075, -0.0012, -0.0043, -0.0016]
Magnitude: 0.6472 → 0.0369
Position 3 (transformers)
Input X: [-0.1334, 0.0027, -0.3410, -0.1478, -0.0307, 0.1240, 0.2642, -0.0063, -0.0856, 0.0626, 0.1602, 0.1385, -0.0427, 0.0122, 0.0991, 0.1081]
Output: [ 0.0064, -0.0084, 0.0092, -0.0173, 0.0068, 0.0119, -0.0100, -0.0027, 0.0027, -0.0073, 0.0036, -0.0076, 0.0022, -0.0070, -0.0095, -0.0070]
Magnitude: 0.5671 → 0.0334
Position 4 (<EOS>)
Input X: [-0.1346, 0.0002, -0.0629, 0.3029, 0.0908, -0.1515, 0.0959, 0.0481, 0.0032, 0.0225, 0.1310, 0.0306, -0.1088, 0.0649, 0.0880, -0.0130]
Output: [ 0.0039, -0.0068, 0.0098, -0.0136, 0.0031, 0.0090, -0.0086, -0.0027, -0.0003, -0.0044, -0.0029, -0.0062, 0.0060, -0.0048, -0.0036, -0.0115]
Magnitude: 0.4462 → 0.0280
Parameter Count¶
Multi-head attention has a lot of parameters:
| Component | Shape | Count |
|---|---|---|
| W_Q (per head) | [16, 8] × 2 heads | 256 |
| W_K (per head) | [16, 8] × 2 heads | 256 |
| W_V (per head) | [16, 8] × 2 heads | 256 |
| W_O | [16, 16] | 256 |
| Total | | 1,024 |
That’s 1,024 parameters just for attention in one transformer block. In GPT-3, with d_model = 12,288 and 96 heads, each of the four projection matrices (Q, K, V, O) is 12,288 × 12,288 ≈ 151 million parameters, so attention alone accounts for roughly 600 million parameters per layer.
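The same counts can be computed directly. A small sanity-check sketch (bias terms are ignored, and the GPT-3 figures assume d_model = 12,288 with 96 heads):

# Attention parameters in one block: W_Q, W_K, W_V across all heads, plus W_O.
def attention_params(d_model, num_heads):
    d_k = d_model // num_heads
    qkv = 3 * num_heads * d_model * d_k  # = 3 * d_model^2 when num_heads * d_k == d_model
    w_o = d_model * d_model
    return qkv + w_o

print(attention_params(16, 2))       # 1,024 for this toy model
print(attention_params(12288, 96))   # 603,979,776 -> roughly 600 million per GPT-3 layer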
What Multi-Head Attention Accomplished¶
We started with embeddings where each position was independent. After multi-head attention:
Each position’s representation incorporates information from previous positions
Two different “perspectives” (heads) contributed to this mixing
The output projection combined these perspectives into a unified representation
The model hasn’t changed the number of positions; we still have 5 vectors of dimension 16. But the content of those vectors is now context-dependent. This is the fundamental innovation of attention.
What’s Next¶
Multi-head attention is done. But we’re not finished with the transformer block.
Next comes the feed-forward network (FFN): a simple two-layer neural network applied to each position independently. Where attention lets positions communicate, the FFN lets each position process the information it has gathered.
# Store for next notebook
multi_head_data = {
    'X': X,
    'tokens': tokens,
    'multi_head_output': multi_head_output,
    'W_O': W_O,
    'concat_output': concat_output
}

print("Multi-head attention complete. Ready for feed-forward network.")

Multi-head attention complete. Ready for feed-forward network.
