
Multi-Head Attention

Why Multiple Heads?

We just computed attention separately for two heads. Each head has its own Q, K, V projections, so each head learns to look at the data differently.

Think of it like having multiple experts examine the same text:

  • One head might specialize in syntactic relationships (subject-verb, noun-adjective)

  • Another might focus on semantic relationships (related concepts, coreference)

  • Another might track positional patterns (nearby words, document structure)

In practice, what heads learn is emergent: we don’t design them for specific tasks. Still, analyses of trained models have found that individual heads often do specialize. Some learn to track quotation marks, others focus on rare words, and others capture long-range dependencies.

The multi-head mechanism lets the model attend to information from different representation subspaces simultaneously. It’s one of the key innovations that made transformers so effective.

The Multi-Head Process

We have attention outputs from 2 heads, each with shape [5, 8]. To get back to d_model = 16:

  1. Concatenate the head outputs: [5, 8] and [5, 8] → [5, 16]

  2. Project through the output matrix $W_O$: [5, 16] @ [16, 16] → [5, 16]

$$\text{MultiHead}(X) = \text{Concat}(\text{head}_0, \text{head}_1) \cdot W_O^T$$

The concatenation is straightforward—just stack the vectors side by side.

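The output projection through $W_O$ is important: it lets the model mix information across heads. Without it, the heads would remain completely independent. With it, every output dimension becomes a learned weighted sum of features from both heads. Because $W_O$ is linear, it can only weight and add these signals; conditional interactions such as “if head 0 sees X and head 1 sees Y, output Z” need a nonlinearity, which later layers such as the feed-forward network provide.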
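A quick way to see the mixing is to split the projection by rows. Write $W$ for the [16, 16] matrix that multiplies the concatenated vector (in this notebook’s code that is $W_O^T$). Rows 0 to 7 of $W$ only ever meet head 0’s features and rows 8 to 15 only head 1’s, so concatenating and then projecting equals projecting each head through its own block of rows and adding the results. The sketch below checks this on hypothetical random values, not the notebook’s weights:

```python
import random

def vec_mat(v, M):
    # Row vector [n] times matrix [n, p] -> [p]
    return [sum(v[k] * M[k][j] for k in range(len(v))) for j in range(len(M[0]))]

rng = random.Random(1)
d_k, d_model = 8, 16
head0 = [rng.gauss(0, 0.1) for _ in range(d_k)]  # one position's output from head 0
head1 = [rng.gauss(0, 0.1) for _ in range(d_k)]  # same position, head 1
W = [[rng.gauss(0, 0.1) for _ in range(d_model)] for _ in range(d_model)]  # hypothetical projection

# Route 1: concatenate, then project
full = vec_mat(head0 + head1, W)

# Route 2: project each head through its own block of rows, then add
from_head0 = vec_mat(head0, W[:d_k])  # rows 0..7
from_head1 = vec_mat(head1, W[d_k:])  # rows 8..15
summed = [a + b for a, b in zip(from_head0, from_head1)]

max_diff = max(abs(a - b) for a, b in zip(full, summed))
print(f"max |concat-then-project - sum of per-head projections| = {max_diff:.2e}")
```

Every output dimension receives a weighted contribution from both heads, which is the cross-head mixing described above.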

import random
import math

random.seed(42)

VOCAB_SIZE = 6
D_MODEL = 16
MAX_SEQ_LEN = 5
NUM_HEADS = 2
D_K = D_MODEL // NUM_HEADS  # 8

TOKEN_NAMES = ["<PAD>", "<BOS>", "<EOS>", "I", "like", "transformers"]

# Helper functions
def random_vector(size, scale=0.1):
    return [random.gauss(0, scale) for _ in range(size)]

def random_matrix(rows, cols, scale=0.1):
    return [[random.gauss(0, scale) for _ in range(cols)] for _ in range(rows)]

def add_vectors(v1, v2):
    return [a + b for a, b in zip(v1, v2)]

def matmul(A, B):
    m, n, p = len(A), len(A[0]), len(B[0])
    return [[sum(A[i][k] * B[k][j] for k in range(n)) for j in range(p)] for i in range(m)]

def transpose(A):
    rows, cols = len(A), len(A[0])
    return [[A[i][j] for i in range(rows)] for j in range(cols)]

def softmax(vec):
    finite_vals = [v for v in vec if v != float('-inf')]
    max_val = max(finite_vals) if finite_vals else 0
    exp_vec = [math.exp(v - max_val) if v != float('-inf') else 0.0 for v in vec]
    total = sum(exp_vec)
    return [e / total for e in exp_vec]

def format_vector(vec, decimals=4):
    return "[" + ", ".join([f"{v:7.{decimals}f}" for v in vec]) + "]"

# Recreate everything from previous notebooks
E_token = [random_vector(D_MODEL) for _ in range(VOCAB_SIZE)]
E_pos = [random_vector(D_MODEL) for _ in range(MAX_SEQ_LEN)]
tokens = [1, 3, 4, 5, 2]
seq_len = len(tokens)
X = [add_vectors(E_token[tokens[i]], E_pos[i]) for i in range(seq_len)]

W_Q = [random_matrix(D_MODEL, D_K) for _ in range(NUM_HEADS)]
W_K = [random_matrix(D_MODEL, D_K) for _ in range(NUM_HEADS)]
W_V = [random_matrix(D_MODEL, D_K) for _ in range(NUM_HEADS)]

Q_all = [matmul(X, W_Q[h]) for h in range(NUM_HEADS)]
K_all = [matmul(X, W_K[h]) for h in range(NUM_HEADS)]
V_all = [matmul(X, W_V[h]) for h in range(NUM_HEADS)]

def attention(Q, K, V):
    seq_len, d_k = len(Q), len(Q[0])
    scale = math.sqrt(d_k)
    scores = matmul(Q, transpose(K))
    scaled = [[s / scale for s in row] for row in scores]
    for i in range(seq_len):
        for j in range(seq_len):
            if j > i:
                scaled[i][j] = float('-inf')
    weights = [softmax(row) for row in scaled]
    return matmul(weights, V), weights

attention_output_all = []
for h in range(NUM_HEADS):
    output, _ = attention(Q_all[h], K_all[h], V_all[h])
    attention_output_all.append(output)

print(f"Recreated attention outputs from previous notebooks")
print(f"Head 0 output shape: [{len(attention_output_all[0])}, {len(attention_output_all[0][0])}]")
print(f"Head 1 output shape: [{len(attention_output_all[1])}, {len(attention_output_all[1][0])}]")
Recreated attention outputs from previous notebooks
Head 0 output shape: [5, 8]
Head 1 output shape: [5, 8]
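Here each head has its own [16, 8] projection. Implementations often store these side by side as one [16, 16] matrix and split the result afterwards; for example, PyTorch’s `nn.MultiheadAttention` packs the Q, K, and V projections into a single `in_proj_weight`. The sketch below, using small hypothetical random matrices rather than this notebook’s weights, checks that projecting with the fused matrix and slicing columns gives the same per-head queries as projecting with each head’s matrix separately:

```python
import random

def matmul(A, B):
    # Plain-Python matrix product: [m, n] @ [n, p] -> [m, p]
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]

rng = random.Random(0)
d_model, num_heads, seq_len = 16, 2, 5
d_k = d_model // num_heads

X = [[rng.gauss(0, 0.1) for _ in range(d_model)] for _ in range(seq_len)]
# Hypothetical per-head query matrices, each [d_model, d_k]
W_Q_heads = [[[rng.gauss(0, 0.1) for _ in range(d_k)] for _ in range(d_model)]
             for _ in range(num_heads)]

# Fuse: row r is head 0's row r followed by head 1's row r -> [d_model, d_model]
W_Q_fused = [sum((W_Q_heads[h][r] for h in range(num_heads)), []) for r in range(d_model)]

Q_fused = matmul(X, W_Q_fused)  # [seq_len, d_model]
# Split columns back into heads: head h owns columns h*d_k .. (h+1)*d_k - 1
Q_split = [[row[h * d_k:(h + 1) * d_k] for row in Q_fused] for h in range(num_heads)]
Q_separate = [matmul(X, W_Q_heads[h]) for h in range(num_heads)]

max_diff = max(abs(a - b)
               for h in range(num_heads)
               for row_a, row_b in zip(Q_split[h], Q_separate[h])
               for a, b in zip(row_a, row_b))
print(f"Fused W_Q shape: [{len(W_Q_fused)}, {len(W_Q_fused[0])}]")
print(f"Max difference between fused and per-head queries: {max_diff:.2e}")
```

This is why each head can be described as working in its own 8-dimensional subspace of the 16-dimensional projection.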

Step 1: Concatenate Head Outputs

Each head produced an output of shape [5, 8]. We concatenate them along the last dimension:

$$\text{concat}[i] = [\text{head}_0[i] \;\|\; \text{head}_1[i]]$$

For each position $i$, we take the 8-dimensional vector from head 0 and the 8-dimensional vector from head 1, and stick them together to get a 16-dimensional vector.

Result shape: [5, 16]
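The concatenation code below hard-codes heads 0 and 1. A minimal sketch of the same step for any number of heads, plus the inverse split, assuming each head output is a list of seq_len rows of equal width d_k (the helper names `concat_heads` and `split_heads` are hypothetical):

```python
def concat_heads(head_outputs):
    """Concatenate per-head outputs [num_heads][seq_len][d_k] -> [seq_len][num_heads * d_k]."""
    seq_len = len(head_outputs[0])
    return [[x for head in head_outputs for x in head[i]] for i in range(seq_len)]

def split_heads(concat, num_heads):
    """Inverse of concat_heads: [seq_len][d_model] -> [num_heads][seq_len][d_model // num_heads]."""
    d_k = len(concat[0]) // num_heads
    return [[row[h * d_k:(h + 1) * d_k] for row in concat] for h in range(num_heads)]

# Toy example: 2 heads, 3 positions, d_k = 2; each value encodes (head, position, index)
heads = [[[100 * h + 10 * i + j for j in range(2)] for i in range(3)] for h in range(2)]
merged = concat_heads(heads)
print(merged[1])  # [10, 11, 110, 111]: head 0's two values, then head 1's
assert split_heads(merged, num_heads=2) == heads
```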

# Concatenate head outputs
concat_output = []
for i in range(seq_len):
    # Concatenate head 0 and head 1 outputs for position i
    concat_row = attention_output_all[0][i] + attention_output_all[1][i]
    concat_output.append(concat_row)

print("Concatenated Attention Outputs")
print(f"Shape: [{seq_len}, {D_MODEL}]")
print()
for i, row in enumerate(concat_output):
    print(f"  pos {i} ({TOKEN_NAMES[tokens[i]]:12s}): {format_vector(row)}")
Concatenated Attention Outputs
Shape: [5, 16]

  pos 0 (<BOS>       ): [ 0.0800,  0.0257, -0.0117, -0.1056,  0.0339, -0.0891, -0.0083, -0.0737,  0.0107, -0.0291, -0.0100, -0.0312,  0.0214,  0.0372,  0.0105,  0.0279]
  pos 1 (I           ): [ 0.0683,  0.0368, -0.0263, -0.0574,  0.0152, -0.0174, -0.0084, -0.0760, -0.0199, -0.0151,  0.0026,  0.0107,  0.0091, -0.0204, -0.0320, -0.0193]
  pos 2 (like        ): [ 0.0247,  0.0789,  0.0074, -0.0635,  0.0180, -0.0098, -0.0184, -0.0173, -0.0320, -0.0102,  0.0178, -0.0153,  0.0433,  0.0026,  0.0002, -0.0198]
  pos 3 (transformers): [ 0.0254,  0.0511, -0.0182, -0.0322,  0.0103, -0.0126, -0.0282,  0.0018, -0.0111, -0.0085,  0.0093,  0.0101,  0.0440,  0.0237,  0.0056, -0.0311]
  pos 4 (<EOS>       ): [ 0.0325,  0.0367, -0.0202, -0.0262,  0.0188, -0.0040, -0.0321,  0.0167, -0.0119, -0.0013, -0.0069,  0.0016,  0.0480,  0.0233,  0.0096, -0.0121]
# Show the concatenation for position 0 in detail
print("Detailed: Concatenation for position 0 (<BOS>)")
print("=" * 70)
print()
print(f"Head 0 output (8 dims): {format_vector(attention_output_all[0][0])}")
print()
print(f"Head 1 output (8 dims): {format_vector(attention_output_all[1][0])}")
print()
print(f"Concatenated (16 dims): {format_vector(concat_output[0])}")
print()
print("First 8 elements come from head 0, last 8 from head 1.")
Detailed: Concatenation for position 0 (<BOS>)
======================================================================

Head 0 output (8 dims): [ 0.0800,  0.0257, -0.0117, -0.1056,  0.0339, -0.0891, -0.0083, -0.0737]

Head 1 output (8 dims): [ 0.0107, -0.0291, -0.0100, -0.0312,  0.0214,  0.0372,  0.0105,  0.0279]

Concatenated (16 dims): [ 0.0800,  0.0257, -0.0117, -0.1056,  0.0339, -0.0891, -0.0083, -0.0737,  0.0107, -0.0291, -0.0100, -0.0312,  0.0214,  0.0372,  0.0105,  0.0279]

First 8 elements come from head 0, last 8 from head 1.

Step 2: Output Projection

The concatenated output is [5, 16]. We project it through the output weight matrix $W_O$:

$$\text{multi\_head\_output} = \text{concat} \cdot W_O^T$$

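Here $W_O$ has shape [d_model, d_model] = [16, 16]. Multiplying by the transpose follows the convention of PyTorch’s nn.Linear, which stores a weight matrix as [out_features, in_features], so output dimension $j$ is the dot product of the input with row $j$ of $W_O$. Because $W_O$ is square and randomly initialized here, either convention works as long as it is applied consistently.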
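With the transposed convention, output element $j$ uses row $j$ of the weight matrix, while the plain convention uses column $j$. A tiny check with a hand-picked 2 × 2 matrix (values chosen only for illustration) makes the difference visible:

```python
def transpose(A):
    return [list(col) for col in zip(*A)]

def vec_mat(v, M):
    # Row vector [n] times matrix [n, p] -> [p]
    return [sum(v[k] * M[k][j] for k in range(len(v))) for j in range(len(M[0]))]

x = [1.0, 2.0]
W = [[1.0, 10.0],       # row 0
     [100.0, 1000.0]]   # row 1

y_transposed = vec_mat(x, transpose(W))  # x @ W^T: element j = x · W[j, :]
y_plain = vec_mat(x, W)                  # x @ W:   element j = x · W[:, j]

print(y_transposed)  # [21.0, 2100.0]
print(y_plain)       # [201.0, 2010.0]
```

Mixing up the two conventions silently produces different numbers, so when checking a single output element by hand, index the matrix the same way the projection code does.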

Why this projection?

Without $W_O$, the heads would be completely separate. The first 8 dimensions would always come from head 0, the last 8 from head 1. There’s no interaction.

With $W_O$, each output dimension becomes a learned linear combination of all dimensions from all heads. The model can learn to:

  • Route head 0’s signal (say, for a verb) and head 1’s signal (say, for its subject) into the same output dimensions

  • Down-weight dimensions from a head whose output is mostly noise

  • Combine syntactic and semantic signals into a unified representation

This mixing is crucial for multi-head attention’s power.
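The difference between the two cases can be checked directly. A sketch with hypothetical random head outputs for one position: nudge only head 1’s output and list which of the 16 output dimensions move. With the identity matrix standing in for “no $W_O$”, only dimensions 8 to 15 respond; with a dense random matrix, all 16 do.

```python
import random

def vec_mat(v, M):
    # Row vector [n] times matrix [n, p] -> [p]
    return [sum(v[k] * M[k][j] for k in range(len(v))) for j in range(len(M[0]))]

rng = random.Random(2)
d_k, d_model = 8, 16
head0 = [rng.gauss(0, 0.1) for _ in range(d_k)]
head1 = [rng.gauss(0, 0.1) for _ in range(d_k)]
head1_perturbed = [v + 0.5 for v in head1]  # change only head 1

identity = [[1.0 if r == c else 0.0 for c in range(d_model)] for r in range(d_model)]
W_dense = [[rng.gauss(0, 0.1) for _ in range(d_model)] for _ in range(d_model)]

def changed_dims(W):
    # Output dimensions whose value moves when only head 1's output changes
    before = vec_mat(head0 + head1, W)
    after = vec_mat(head0 + head1_perturbed, W)
    return [j for j in range(d_model) if abs(after[j] - before[j]) > 1e-12]

print("identity:", changed_dims(identity))  # [8, 9, 10, 11, 12, 13, 14, 15]
print("dense W: ", changed_dims(W_dense))
```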

# Initialize output projection matrix
W_O = random_matrix(D_MODEL, D_MODEL)  # [16, 16]

print(f"Output Projection Matrix W_O")
print(f"Shape: [{D_MODEL}, {D_MODEL}]")
print()
print("(16×16 = 256 learnable parameters for mixing heads)")
Output Projection Matrix W_O
Shape: [16, 16]

(16×16 = 256 learnable parameters for mixing heads)
# Apply output projection: concat @ W_O^T
W_O_T = transpose(W_O)
multi_head_output = matmul(concat_output, W_O_T)

print("Multi-Head Attention Output")
print(f"Shape: [{seq_len}, {D_MODEL}] @ [{D_MODEL}, {D_MODEL}] = [{seq_len}, {D_MODEL}]")
print()
for i, row in enumerate(multi_head_output):
    print(f"  pos {i} ({TOKEN_NAMES[tokens[i]]:12s}): {format_vector(row)}")
Multi-Head Attention Output
Shape: [5, 16] @ [16, 16] = [5, 16]

  pos 0 (<BOS>       ): [ 0.0334,  0.0033, -0.0041, -0.0073,  0.0185,  0.0074,  0.0169,  0.0107,  0.0277,  0.0060,  0.0222,  0.0241,  0.0074,  0.0067, -0.0067,  0.0063]
  pos 1 (I           ): [ 0.0269,  0.0066,  0.0113, -0.0154,  0.0114,  0.0032, -0.0065, -0.0108,  0.0190, -0.0091,  0.0180,  0.0097, -0.0075,  0.0061, -0.0079,  0.0110]
  pos 2 (like        ): [ 0.0085,  0.0086,  0.0159, -0.0177,  0.0026,  0.0205, -0.0057, -0.0055,  0.0059, -0.0043,  0.0007, -0.0053,  0.0075, -0.0012, -0.0043, -0.0016]
  pos 3 (transformers): [ 0.0064, -0.0084,  0.0092, -0.0173,  0.0068,  0.0119, -0.0100, -0.0027,  0.0027, -0.0073,  0.0036, -0.0076,  0.0022, -0.0070, -0.0095, -0.0070]
  pos 4 (<EOS>       ): [ 0.0039, -0.0068,  0.0098, -0.0136,  0.0031,  0.0090, -0.0086, -0.0027, -0.0003, -0.0044, -0.0029, -0.0062,  0.0060, -0.0048, -0.0036, -0.0115]
# Detailed calculation for one element
# We computed concat @ W_O^T, so output[0][j] = concat[0] · W_O^T[:, j] = concat[0] · W_O[j, :]
print("Detailed: Computing output[0][0]")
print("=" * 70)
print()
print("output[0][0] = concat[0] · W_O^T[:, 0] = concat[0] · W_O[0, :]")
print()
print(f"concat[0] (16 dims): {format_vector(concat_output[0])}")
print()

row_0 = W_O[0]  # row 0 of W_O is column 0 of W_O^T
result = sum(concat_output[0][j] * row_0[j] for j in range(D_MODEL))
print(f"Dot product with W_O[0, :] = {result:.6f}")
print(f"Actual output[0][0]        = {multi_head_output[0][0]:.6f}")
print()

# For contrast: column 0 of W_O belongs to the untransposed convention (concat @ W_O)
col_0 = [W_O[i][0] for i in range(D_MODEL)]
print(f"W_O[:, 0] (16 dims): {format_vector(col_0)}")
wrong = sum(concat_output[0][j] * col_0[j] for j in range(D_MODEL))
print(f"Dot product with W_O[:, 0] = {wrong:.6f} (does not match)")
Detailed: Computing output[0][0]
======================================================================

output[0][0] = concat[0] · W_O^T[:, 0] = concat[0] · W_O[0, :]

concat[0] (16 dims): [ 0.0800,  0.0257, -0.0117, -0.1056,  0.0339, -0.0891, -0.0083, -0.0737,  0.0107, -0.0291, -0.0100, -0.0312,  0.0214,  0.0372,  0.0105,  0.0279]

Dot product with W_O[0, :] = 0.033421
Actual output[0][0]        = 0.033421

W_O[:, 0] (16 dims): [-0.0262, -0.0621,  0.1544, -0.1022,  0.0636,  0.1289, -0.0064, -0.1005, -0.0360, -0.0405,  0.0004,  0.1525, -0.0449, -0.0366,  0.0846,  0.0188]
Dot product with W_O[:, 0] = -0.001460 (does not match)

Comparing Input and Output

Let’s compare what went into multi-head attention (the original embeddings $X$) with what came out.

The output has the same shape as the input [5, 16], but each vector now incorporates information from other positions via attention. The representation for “like” now contains information about “I” and “<BOS>”—it’s no longer just about the word “like” in isolation.

print("Comparison: Input vs Multi-Head Attention Output")
print("=" * 80)
print()
for i in range(seq_len):
    print(f"Position {i} ({TOKEN_NAMES[tokens[i]]})")
    print(f"  Input X:  {format_vector(X[i])}")
    print(f"  Output:   {format_vector(multi_head_output[i])}")
    
    # Compute magnitude change
    mag_in = math.sqrt(sum(x**2 for x in X[i]))
    mag_out = math.sqrt(sum(x**2 for x in multi_head_output[i]))
    print(f"  Magnitude: {mag_in:.4f} → {mag_out:.4f}")
    print()
Comparison: Input vs Multi-Head Attention Output
================================================================================

Position 0 (<BOS>)
  Input X:  [ 0.1473,  0.1281,  0.1995, -0.0465,  0.2125, -0.1338, -0.0829, -0.0638,  0.0722,  0.1183,  0.1193,  0.0937, -0.1594, -0.0402,  0.1124, -0.2064]
  Output:   [ 0.0334,  0.0033, -0.0041, -0.0073,  0.0185,  0.0074,  0.0169,  0.0107,  0.0277,  0.0060,  0.0222,  0.0241,  0.0074,  0.0067, -0.0067,  0.0063]
  Magnitude: 0.5278 → 0.0637

Position 1 (I)
  Input X:  [-0.1254, -0.0720,  0.1255, -0.0556, -0.0678,  0.3698, -0.1265, -0.1463,  0.0866,  0.0181,  0.0726, -0.0374,  0.2312, -0.0091,  0.0860, -0.0251]
  Output:   [ 0.0269,  0.0066,  0.0113, -0.0154,  0.0114,  0.0032, -0.0065, -0.0108,  0.0190, -0.0091,  0.0180,  0.0097, -0.0075,  0.0061, -0.0079,  0.0110]
  Magnitude: 0.5427 → 0.0507

Position 2 (like)
  Input X:  [ 0.2319, -0.2747, -0.0089,  0.0576,  0.1430, -0.0957,  0.1571,  0.2913,  0.2154,  0.0103, -0.0510, -0.1353, -0.0296, -0.0371, -0.0262,  0.2770]
  Output:   [ 0.0085,  0.0086,  0.0159, -0.0177,  0.0026,  0.0205, -0.0057, -0.0055,  0.0059, -0.0043,  0.0007, -0.0053,  0.0075, -0.0012, -0.0043, -0.0016]
  Magnitude: 0.6472 → 0.0369

Position 3 (transformers)
  Input X:  [-0.1334,  0.0027, -0.3410, -0.1478, -0.0307,  0.1240,  0.2642, -0.0063, -0.0856,  0.0626,  0.1602,  0.1385, -0.0427,  0.0122,  0.0991,  0.1081]
  Output:   [ 0.0064, -0.0084,  0.0092, -0.0173,  0.0068,  0.0119, -0.0100, -0.0027,  0.0027, -0.0073,  0.0036, -0.0076,  0.0022, -0.0070, -0.0095, -0.0070]
  Magnitude: 0.5671 → 0.0334

Position 4 (<EOS>)
  Input X:  [-0.1346,  0.0002, -0.0629,  0.3029,  0.0908, -0.1515,  0.0959,  0.0481,  0.0032,  0.0225,  0.1310,  0.0306, -0.1088,  0.0649,  0.0880, -0.0130]
  Output:   [ 0.0039, -0.0068,  0.0098, -0.0136,  0.0031,  0.0090, -0.0086, -0.0027, -0.0003, -0.0044, -0.0029, -0.0062,  0.0060, -0.0048, -0.0036, -0.0115]
  Magnitude: 0.4462 → 0.0280
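The output vectors are roughly 8 to 18 times shorter than the inputs. This is largely a consequence of this notebook’s initialization rather than of attention itself. Multiplying by a matrix whose entries are drawn from $N(0, \text{scale}^2)$ scales a vector’s root-mean-square by about $\text{scale} \cdot \sqrt{\text{fan\_in}}$, which is $0.1 \cdot \sqrt{16} = 0.4$ for both $W_V$ and $W_O$, so two projections give roughly 0.16; the attention-weighted averaging over positions then shrinks the later positions further. The sketch below is a back-of-the-envelope check of the 0.4 factor on fresh random matrices and says nothing about trained weights:

```python
import math
import random

def rms(v):
    # Root-mean-square of a vector
    return math.sqrt(sum(x * x for x in v) / len(v))

def empirical_gain(fan_in, fan_out, scale, trials=1000, seed=0):
    """Estimate how much x @ W scales rms(x) when W's entries are drawn from N(0, scale^2)."""
    rng = random.Random(seed)
    squared_ratios = []
    for _ in range(trials):
        x = [rng.gauss(0, 1) for _ in range(fan_in)]
        W = [[rng.gauss(0, scale) for _ in range(fan_out)] for _ in range(fan_in)]
        y = [sum(x[k] * W[k][j] for k in range(fan_in)) for j in range(fan_out)]
        squared_ratios.append((rms(y) / rms(x)) ** 2)
    # The mean of the squared ratios is an unbiased estimate of scale^2 * fan_in
    return math.sqrt(sum(squared_ratios) / trials)

predicted = 0.1 * math.sqrt(16)
gain_v = empirical_gain(16, 8, 0.1)   # shaped like one head's W_V
gain_o = empirical_gain(16, 16, 0.1)  # shaped like W_O
print(f"predicted gain per projection: {predicted:.3f}")
print(f"measured, W_V-shaped [16, 8]:  {gain_v:.3f}")
print(f"measured, W_O-shaped [16, 16]: {gain_o:.3f}")
print(f"two projections in a row: about {predicted ** 2:.2f}x")
```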

Parameter Count

Multi-head attention has a lot of parameters:

| Component | Shape | Count |
|---|---|---|
| $W_Q$ (per head) | [16, 8] × 2 heads | 256 |
| $W_K$ (per head) | [16, 8] × 2 heads | 256 |
| $W_V$ (per head) | [16, 8] × 2 heads | 256 |
| $W_O$ | [16, 16] | 256 |
| Total | | 1,024 |

That’s 1,024 parameters just for attention in one transformer block. In GPT-3, with d_model = 12,288 and 96 heads, the attention weights in each layer number about 604 million (4 × 12,288² = 603,979,776, not counting biases).
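Because the heads split d_model rather than adding to it (num_heads × d_k = d_model), the count does not depend on the number of heads: Q, K, V, and O each contribute d_model² weights, for 4 · d_model² in total. Biases, which this notebook omits, would add another 4 · d_model. A sketch of that arithmetic (the function name `attention_params` is hypothetical):

```python
def attention_params(d_model, num_heads, bias=False):
    """Weights in one multi-head attention layer, assuming d_k = d_v = d_model // num_heads."""
    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
    d_k = d_model // num_heads
    qkv = 3 * num_heads * d_model * d_k  # W_Q, W_K, W_V across all heads
    out = d_model * d_model              # W_O
    biases = 4 * d_model if bias else 0  # one bias vector per projection, if used
    return qkv + out + biases

print(attention_params(16, 2))       # 1024, matching the table above
print(attention_params(12_288, 96))  # 603979776 for a GPT-3-sized layer
```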

What Multi-Head Attention Accomplished

We started with embeddings where each position was independent. After multi-head attention:

  • Each position’s representation incorporates information from previous positions

  • Two different “perspectives” (heads) contributed to this mixing

  • The output projection combined these perspectives into a unified representation

The model hasn’t changed the number of positions—we still have 5 vectors of dimension 16. But the content of those vectors is now context-dependent. This is the fundamental innovation of attention.
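The claim that each position draws only on itself and earlier positions can be checked directly. The self-contained sketch below uses a hypothetical single-head causal attention on random inputs (not this notebook’s weights), changes only the last input vector, and confirms that the outputs at earlier positions are unchanged:

```python
import math
import random

def matmul(A, B):
    # Plain-Python matrix product: [m, n] @ [n, p] -> [m, p]
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]

def causal_attention(X, W_Q, W_K, W_V):
    Q, K, V = matmul(X, W_Q), matmul(X, W_K), matmul(X, W_V)
    d_k = len(Q[0])
    out = []
    for i in range(len(X)):
        # Scaled scores only for positions j <= i (the causal mask)
        scores = [sum(q * k for q, k in zip(Q[i], K[j])) / math.sqrt(d_k) for j in range(i + 1)]
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        out.append([sum(weights[j] * V[j][c] for j in range(i + 1)) for c in range(len(V[0]))])
    return out

def rand_matrix(rng, rows, cols):
    return [[rng.gauss(0, 0.1) for _ in range(cols)] for _ in range(rows)]

rng = random.Random(3)
seq_len, d_model, d_k = 5, 16, 8
X = rand_matrix(rng, seq_len, d_model)
W_Q, W_K, W_V = (rand_matrix(rng, d_model, d_k) for _ in range(3))

X_edited = [row[:] for row in X]
X_edited[-1] = [v + 1.0 for v in X_edited[-1]]  # change only the last token

before = causal_attention(X, W_Q, W_K, W_V)
after = causal_attention(X_edited, W_Q, W_K, W_V)
for i in range(seq_len):
    print(f"position {i}: output changed = {before[i] != after[i]}")
```

Only the last position’s output changes, because no earlier position is allowed to attend to it.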

What’s Next

Multi-head attention is done. But we’re not finished with the transformer block.

Next comes the feed-forward network (FFN)—a simple two-layer neural network applied to each position independently. Where attention lets positions communicate, the FFN lets each position process the information it’s gathered.

# Store for next notebook
multi_head_data = {
    'X': X,
    'tokens': tokens,
    'multi_head_output': multi_head_output,
    'W_O': W_O,
    'concat_output': concat_output
}
print("Multi-head attention complete. Ready for feed-forward network.")
Multi-head attention complete. Ready for feed-forward network.