
Multi-Head Attention

Why Multiple Heads?

We just computed attention separately for two heads. Each head has its own Q, K, V projections, so each head learns to look at the data differently.

Think of it like having multiple experts examine the same text:

  • One head might specialize in syntactic relationships (subject-verb, noun-adjective)

  • Another might focus on semantic relationships (related concepts, coreference)

  • Another might track positional patterns (nearby words, document structure)

In practice, what heads learn is emergent; we don’t design them for specific tasks. But research has shown that different heads genuinely specialize. Some learn to track quotation marks. Others focus on rare words. Others capture long-range dependencies.

The multi-head mechanism lets the model attend to information from different representation subspaces simultaneously. It’s one of the key innovations that made transformers so effective.

The Multi-Head Process

We have attention outputs from 2 heads, each with shape [5, 8]. To get back to d_model = 16:

  1. Concatenate the head outputs: [5, 8] and [5, 8] → [5, 16]

  2. Project through the output matrix $W_O$: [5, 16] @ [16, 16] → [5, 16]

$$\text{MultiHead}(X) = \text{Concat}(\text{head}_0, \text{head}_1) \cdot W_O$$

The concatenation is straightforward. Just stack the vectors side by side.

The output projection through $W_O$ is important: it lets the model mix information across heads. Without it, the heads would remain completely independent. With it, the model can learn combinations like “if head 0 sees X and head 1 sees Y, output Z.”

import random
import math

random.seed(42)

VOCAB_SIZE = 6
D_MODEL = 16
MAX_SEQ_LEN = 5
NUM_HEADS = 2
D_K = D_MODEL // NUM_HEADS  # 8

TOKEN_NAMES = ["<PAD>", "<BOS>", "<EOS>", "I", "like", "transformers"]
# Helper functions
def random_vector(size, scale=0.1):
    return [random.gauss(0, scale) for _ in range(size)]

def random_matrix(rows, cols, scale=0.1):
    return [[random.gauss(0, scale) for _ in range(cols)] for _ in range(rows)]

def add_vectors(v1, v2):
    return [a + b for a, b in zip(v1, v2)]

def matmul(A, B):
    m, n, p = len(A), len(A[0]), len(B[0])
    return [[sum(A[i][k] * B[k][j] for k in range(n)) for j in range(p)] for i in range(m)]

def transpose(A):
    rows, cols = len(A), len(A[0])
    return [[A[i][j] for i in range(rows)] for j in range(cols)]

def softmax(vec):
    finite_vals = [v for v in vec if v != float('-inf')]
    max_val = max(finite_vals) if finite_vals else 0
    exp_vec = [math.exp(v - max_val) if v != float('-inf') else 0.0 for v in vec]
    total = sum(exp_vec)
    return [e / total for e in exp_vec]

def format_vector(vec, decimals=4):
    return "[" + ", ".join([f"{v:7.{decimals}f}" for v in vec]) + "]"
# Recreate everything from previous notebooks
E_token = [random_vector(D_MODEL) for _ in range(VOCAB_SIZE)]
E_pos = [random_vector(D_MODEL) for _ in range(MAX_SEQ_LEN)]
tokens = [1, 3, 4, 5, 2]
seq_len = len(tokens)
X = [add_vectors(E_token[tokens[i]], E_pos[i]) for i in range(seq_len)]

W_Q = [random_matrix(D_MODEL, D_K) for _ in range(NUM_HEADS)]
W_K = [random_matrix(D_MODEL, D_K) for _ in range(NUM_HEADS)]
W_V = [random_matrix(D_MODEL, D_K) for _ in range(NUM_HEADS)]

Q_all = [matmul(X, W_Q[h]) for h in range(NUM_HEADS)]
K_all = [matmul(X, W_K[h]) for h in range(NUM_HEADS)]
V_all = [matmul(X, W_V[h]) for h in range(NUM_HEADS)]

def attention(Q, K, V):
    seq_len, d_k = len(Q), len(Q[0])
    scale = math.sqrt(d_k)
    scores = matmul(Q, transpose(K))
    scaled = [[s / scale for s in row] for row in scores]
    for i in range(seq_len):
        for j in range(seq_len):
            if j > i:
                scaled[i][j] = float('-inf')
    weights = [softmax(row) for row in scaled]
    return matmul(weights, V), weights

attention_output_all = []
for h in range(NUM_HEADS):
    output, _ = attention(Q_all[h], K_all[h], V_all[h])
    attention_output_all.append(output)

print(f"Recreated attention outputs from previous notebooks")
print(f"Head 0 output shape: [{len(attention_output_all[0])}, {len(attention_output_all[0][0])}]")
print(f"Head 1 output shape: [{len(attention_output_all[1])}, {len(attention_output_all[1][0])}]")
Recreated attention outputs from previous notebooks
Head 0 output shape: [5, 8]
Head 1 output shape: [5, 8]

Step 1: Concatenate Head Outputs

Each head produced an output of shape [5, 8]. We concatenate them along the last dimension:

$$\text{concat}[i] = [\text{head}_0[i] \;\|\; \text{head}_1[i]]$$

For each position $i$, we take the 8-dimensional vector from head 0 and the 8-dimensional vector from head 1, and stick them together to get a 16-dimensional vector.

Result shape: [5, 16]

# Concatenate head outputs
concat_output = []
for i in range(seq_len):
    # Concatenate head 0 and head 1 outputs for position i
    concat_row = attention_output_all[0][i] + attention_output_all[1][i]
    concat_output.append(concat_row)

print("Concatenated Attention Outputs")
print(f"Shape: [{seq_len}, {D_MODEL}]")
print()
for i, row in enumerate(concat_output):
    print(f"  pos {i} ({TOKEN_NAMES[tokens[i]]:12s}): {format_vector(row)}")
Concatenated Attention Outputs
Shape: [5, 16]

  pos 0 (<BOS>       ): [ 0.0800,  0.0257, -0.0117, -0.1056,  0.0339, -0.0891, -0.0083, -0.0737,  0.0107, -0.0291, -0.0100, -0.0312,  0.0214,  0.0372,  0.0105,  0.0279]
  pos 1 (I           ): [ 0.0683,  0.0368, -0.0263, -0.0574,  0.0152, -0.0174, -0.0084, -0.0760, -0.0199, -0.0151,  0.0026,  0.0107,  0.0091, -0.0204, -0.0320, -0.0193]
  pos 2 (like        ): [ 0.0247,  0.0789,  0.0074, -0.0635,  0.0180, -0.0098, -0.0184, -0.0173, -0.0320, -0.0102,  0.0178, -0.0153,  0.0433,  0.0026,  0.0002, -0.0198]
  pos 3 (transformers): [ 0.0254,  0.0511, -0.0182, -0.0322,  0.0103, -0.0126, -0.0282,  0.0018, -0.0111, -0.0085,  0.0093,  0.0101,  0.0440,  0.0237,  0.0056, -0.0311]
  pos 4 (<EOS>       ): [ 0.0325,  0.0367, -0.0202, -0.0262,  0.0188, -0.0040, -0.0321,  0.0167, -0.0119, -0.0013, -0.0069,  0.0016,  0.0480,  0.0233,  0.0096, -0.0121]
# Show the concatenation for position 0 in detail
print("Detailed: Concatenation for position 0 (<BOS>)")
print("=" * 70)
print()
print(f"Head 0 output (8 dims): {format_vector(attention_output_all[0][0])}")
print()
print(f"Head 1 output (8 dims): {format_vector(attention_output_all[1][0])}")
print()
print(f"Concatenated (16 dims): {format_vector(concat_output[0])}")
print()
print("First 8 elements come from head 0, last 8 from head 1.")
Detailed: Concatenation for position 0 (<BOS>)
======================================================================

Head 0 output (8 dims): [ 0.0800,  0.0257, -0.0117, -0.1056,  0.0339, -0.0891, -0.0083, -0.0737]

Head 1 output (8 dims): [ 0.0107, -0.0291, -0.0100, -0.0312,  0.0214,  0.0372,  0.0105,  0.0279]

Concatenated (16 dims): [ 0.0800,  0.0257, -0.0117, -0.1056,  0.0339, -0.0891, -0.0083, -0.0737,  0.0107, -0.0291, -0.0100, -0.0312,  0.0214,  0.0372,  0.0105,  0.0279]

First 8 elements come from head 0, last 8 from head 1.

Step 2: Output Projection

The concatenated output is [5, 16]. We project it through the output weight matrix $W_O$:

$$\text{multi\_head\_output} = \text{concat} \cdot W_O^T$$

where $W_O$ has shape [d_model, d_model] = [16, 16]. (We store $W_O$ in [out, in] form and multiply by its transpose, the same convention PyTorch's nn.Linear uses; because $W_O$ is square, this is equivalent to the $\text{Concat}(\cdot)\,W_O$ form above with the weights relabeled.)

Why this projection?

Without $W_O$, the heads would be completely separate. The first 8 dimensions would always come from head 0, the last 8 from head 1. There’s no interaction.

With $W_O$, each output dimension becomes a learned combination of all dimensions from all heads. The model can learn:

  • “If head 0 found a verb and head 1 found the subject, strengthen the connection”

  • “Suppress noise when heads disagree”

  • “Combine syntactic and semantic signals into a unified representation”

This mixing is crucial for multi-head attention’s power.
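
One way to see what the projection buys us (a small sketch, not part of the original walkthrough): if $W_O$ were the identity matrix, each output dimension would simply copy one dimension of the concatenation, so head 0 and head 1 would stay in separate halves of the vector; a dense $W_O$ lets every output dimension read from both heads.

# Sketch: with an identity projection, the heads stay unmixed.
# Each output dimension j would equal concat dimension j, i.e. it would
# depend on exactly one head. A dense W_O mixes all 16 dimensions instead.
W_O_identity = [[1.0 if i == j else 0.0 for j in range(D_MODEL)] for i in range(D_MODEL)]
unmixed = matmul(concat_output, transpose(W_O_identity))
print("Identity projection leaves the concatenation unchanged:", unmixed == concat_output)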

# Initialize output projection matrix
W_O = random_matrix(D_MODEL, D_MODEL)  # [16, 16]

print(f"Output Projection Matrix W_O")
print(f"Shape: [{D_MODEL}, {D_MODEL}]")
print()
print("(16×16 = 256 learnable parameters for mixing heads)")
Output Projection Matrix W_O
Shape: [16, 16]

(16×16 = 256 learnable parameters for mixing heads)
# Apply output projection: concat @ W_O^T
W_O_T = transpose(W_O)
multi_head_output = matmul(concat_output, W_O_T)

print("Multi-Head Attention Output")
print(f"Shape: [{seq_len}, {D_MODEL}] @ [{D_MODEL}, {D_MODEL}] = [{seq_len}, {D_MODEL}]")
print()
for i, row in enumerate(multi_head_output):
    print(f"  pos {i} ({TOKEN_NAMES[tokens[i]]:12s}): {format_vector(row)}")
Multi-Head Attention Output
Shape: [5, 16] @ [16, 16] = [5, 16]

  pos 0 (<BOS>       ): [ 0.0334,  0.0033, -0.0041, -0.0073,  0.0185,  0.0074,  0.0169,  0.0107,  0.0277,  0.0060,  0.0222,  0.0241,  0.0074,  0.0067, -0.0067,  0.0063]
  pos 1 (I           ): [ 0.0269,  0.0066,  0.0113, -0.0154,  0.0114,  0.0032, -0.0065, -0.0108,  0.0190, -0.0091,  0.0180,  0.0097, -0.0075,  0.0061, -0.0079,  0.0110]
  pos 2 (like        ): [ 0.0085,  0.0086,  0.0159, -0.0177,  0.0026,  0.0205, -0.0057, -0.0055,  0.0059, -0.0043,  0.0007, -0.0053,  0.0075, -0.0012, -0.0043, -0.0016]
  pos 3 (transformers): [ 0.0064, -0.0084,  0.0092, -0.0173,  0.0068,  0.0119, -0.0100, -0.0027,  0.0027, -0.0073,  0.0036, -0.0076,  0.0022, -0.0070, -0.0095, -0.0070]
  pos 4 (<EOS>       ): [ 0.0039, -0.0068,  0.0098, -0.0136,  0.0031,  0.0090, -0.0086, -0.0027, -0.0003, -0.0044, -0.0029, -0.0062,  0.0060, -0.0048, -0.0036, -0.0115]
# Detailed calculation for one element
print("Detailed: Computing output[0][0]")
print("=" * 70)
print()
print("output[0][0] = concat[0] · W_O[:, 0]")
print()
print(f"concat[0] (16 dims): {format_vector(concat_output[0])}")
print()
# Column 0 of W_O^T is row 0 of W_O; that is the vector output[0][0] is dotted with
row_0 = W_O[0]

result = sum(concat_output[0][j] * row_0[j] for j in range(D_MODEL))
print(f"Dot product = {result:.6f}")
print(f"Actual output[0][0] = {multi_head_output[0][0]:.6f}")
Detailed: Computing output[0][0]
======================================================================

output[0][0] = concat[0] · W_O^T[:, 0]  (= concat[0] · row 0 of W_O)

concat[0] (16 dims): [ 0.0800,  0.0257, -0.0117, -0.1056,  0.0339, -0.0891, -0.0083, -0.0737,  0.0107, -0.0291, -0.0100, -0.0312,  0.0214,  0.0372,  0.0105,  0.0279]

Dot product = 0.033421
Actual output[0][0] = 0.033421

Comparing Input and Output

Let’s compare what went into multi-head attention (the original embeddings XX) with what came out.

The output has the same shape as the input [5, 16], but each vector now incorporates information from other positions via attention. The representation for “like” now contains information about “I” and “<BOS>”; it’s no longer just about the word “like” in isolation.

print("Comparison: Input vs Multi-Head Attention Output")
print("=" * 80)
print()
for i in range(seq_len):
    print(f"Position {i} ({TOKEN_NAMES[tokens[i]]})")
    print(f"  Input X:  {format_vector(X[i])}")
    print(f"  Output:   {format_vector(multi_head_output[i])}")
    
    # Compute magnitude change
    mag_in = math.sqrt(sum(x**2 for x in X[i]))
    mag_out = math.sqrt(sum(x**2 for x in multi_head_output[i]))
    print(f"  Magnitude: {mag_in:.4f} → {mag_out:.4f}")
    print()
Comparison: Input vs Multi-Head Attention Output
================================================================================

Position 0 (<BOS>)
  Input X:  [ 0.1473,  0.1281,  0.1995, -0.0465,  0.2125, -0.1338, -0.0829, -0.0638,  0.0722,  0.1183,  0.1193,  0.0937, -0.1594, -0.0402,  0.1124, -0.2064]
  Output:   [ 0.0334,  0.0033, -0.0041, -0.0073,  0.0185,  0.0074,  0.0169,  0.0107,  0.0277,  0.0060,  0.0222,  0.0241,  0.0074,  0.0067, -0.0067,  0.0063]
  Magnitude: 0.5278 → 0.0637

Position 1 (I)
  Input X:  [-0.1254, -0.0720,  0.1255, -0.0556, -0.0678,  0.3698, -0.1265, -0.1463,  0.0866,  0.0181,  0.0726, -0.0374,  0.2312, -0.0091,  0.0860, -0.0251]
  Output:   [ 0.0269,  0.0066,  0.0113, -0.0154,  0.0114,  0.0032, -0.0065, -0.0108,  0.0190, -0.0091,  0.0180,  0.0097, -0.0075,  0.0061, -0.0079,  0.0110]
  Magnitude: 0.5427 → 0.0507

Position 2 (like)
  Input X:  [ 0.2319, -0.2747, -0.0089,  0.0576,  0.1430, -0.0957,  0.1571,  0.2913,  0.2154,  0.0103, -0.0510, -0.1353, -0.0296, -0.0371, -0.0262,  0.2770]
  Output:   [ 0.0085,  0.0086,  0.0159, -0.0177,  0.0026,  0.0205, -0.0057, -0.0055,  0.0059, -0.0043,  0.0007, -0.0053,  0.0075, -0.0012, -0.0043, -0.0016]
  Magnitude: 0.6472 → 0.0369

Position 3 (transformers)
  Input X:  [-0.1334,  0.0027, -0.3410, -0.1478, -0.0307,  0.1240,  0.2642, -0.0063, -0.0856,  0.0626,  0.1602,  0.1385, -0.0427,  0.0122,  0.0991,  0.1081]
  Output:   [ 0.0064, -0.0084,  0.0092, -0.0173,  0.0068,  0.0119, -0.0100, -0.0027,  0.0027, -0.0073,  0.0036, -0.0076,  0.0022, -0.0070, -0.0095, -0.0070]
  Magnitude: 0.5671 → 0.0334

Position 4 (<EOS>)
  Input X:  [-0.1346,  0.0002, -0.0629,  0.3029,  0.0908, -0.1515,  0.0959,  0.0481,  0.0032,  0.0225,  0.1310,  0.0306, -0.1088,  0.0649,  0.0880, -0.0130]
  Output:   [ 0.0039, -0.0068,  0.0098, -0.0136,  0.0031,  0.0090, -0.0086, -0.0027, -0.0003, -0.0044, -0.0029, -0.0062,  0.0060, -0.0048, -0.0036, -0.0115]
  Magnitude: 0.4462 → 0.0280
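
To make the context dependence concrete, here is a small check (a sketch, not part of the original walkthrough): swap the token at position 1 (“I” → “transformers”) and rerun the same weights. The outputs at later positions change even though their own embeddings did not; position 0 stays the same because the causal mask lets it attend only to itself.

# Sketch: the output is context-dependent. Swap the token at position 1 and
# recompute multi-head attention with the same weights; positions that attend
# to the changed token produce different vectors.
tokens_alt = [1, 5, 4, 5, 2]  # "I" replaced by "transformers" at position 1
X_alt = [add_vectors(E_token[tokens_alt[i]], E_pos[i]) for i in range(seq_len)]

head_outputs_alt = []
for h in range(NUM_HEADS):
    out, _ = attention(matmul(X_alt, W_Q[h]), matmul(X_alt, W_K[h]), matmul(X_alt, W_V[h]))
    head_outputs_alt.append(out)

concat_alt = [head_outputs_alt[0][i] + head_outputs_alt[1][i] for i in range(seq_len)]
output_alt = matmul(concat_alt, W_O_T)

for i in range(seq_len):
    diff = math.sqrt(sum((a - b) ** 2 for a, b in zip(multi_head_output[i], output_alt[i])))
    print(f"  pos {i} ({TOKEN_NAMES[tokens[i]]:12s}): output change = {diff:.4f}")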

Parameter Count

Multi-head attention has a lot of parameters:

| Component | Shape | Count |
| --- | --- | --- |
| $W_Q$ (per head) | [16, 8] × 2 heads | 256 |
| $W_K$ (per head) | [16, 8] × 2 heads | 256 |
| $W_V$ (per head) | [16, 8] × 2 heads | 256 |
| $W_O$ | [16, 16] | 256 |
| Total | | 1,024 |

That’s 1,024 parameters just for attention in one transformer block. In GPT-3, with d_model = 12,288 and 96 heads, each of the four projections ($W_Q$, $W_K$, $W_V$ counted across all heads, plus $W_O$) holds 12,288² ≈ 151 million parameters, so attention accounts for roughly 600 million parameters per layer.
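
To double-check the table, here is a quick count over the matrices actually allocated in this notebook (a sketch; it assumes no bias terms, matching the code above, and applies the same 4·d_model² accounting at GPT-3 scale):

# Count the attention parameters allocated in this notebook (no biases here)
def num_params(M):
    return len(M) * len(M[0])

qkv_params = sum(num_params(W[h]) for W in (W_Q, W_K, W_V) for h in range(NUM_HEADS))
total_params = qkv_params + num_params(W_O)
print(f"Q/K/V parameters ({NUM_HEADS} heads): {qkv_params}")  # 3 * 2 * 16 * 8 = 768
print(f"W_O parameters: {num_params(W_O)}")                   # 16 * 16 = 256
print(f"Total attention parameters: {total_params}")          # 1,024

# The same accounting at GPT-3 scale: 4 * d_model^2 per layer (ignoring biases)
d_model_gpt3 = 12288
print(f"GPT-3 attention parameters per layer: {4 * d_model_gpt3 ** 2:,}")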

What Multi-Head Attention Accomplished

We started with embeddings where each position was independent. After multi-head attention:

  • Each position’s representation incorporates information from previous positions

  • Two different “perspectives” (heads) contributed to this mixing

  • The output projection combined these perspectives into a unified representation

The model hasn’t changed the number of positions; we still have 5 vectors of dimension 16. But the content of those vectors is now context-dependent. This is the fundamental innovation of attention.
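
To tie the pieces together, here is one way to package the whole procedure into a single function (a sketch reusing the helpers and weights defined in this notebook; a real implementation would batch the heads into one matrix operation):

# Sketch: the full multi-head pipeline as one function, reusing the helpers
# (attention, matmul, transpose) and the per-head weight lists from above.
def multi_head_attention(X, W_Q, W_K, W_V, W_O):
    num_heads = len(W_Q)
    seq_len = len(X)
    # 1. Scaled dot-product attention independently per head
    head_outputs = []
    for h in range(num_heads):
        out, _ = attention(matmul(X, W_Q[h]), matmul(X, W_K[h]), matmul(X, W_V[h]))
        head_outputs.append(out)
    # 2. Concatenate the heads along the feature dimension -> [seq_len, d_model]
    concat = [[x for h in range(num_heads) for x in head_outputs[h][i]] for i in range(seq_len)]
    # 3. Mix the heads with the output projection: concat @ W_O^T
    return matmul(concat, transpose(W_O))

# This should reproduce the step-by-step result exactly
check = multi_head_attention(X, W_Q, W_K, W_V, W_O)
max_diff = max(abs(a - b) for ra, rb in zip(check, multi_head_output) for a, b in zip(ra, rb))
print(f"Max difference vs the step-by-step output: {max_diff}")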

What’s Next

Multi-head attention is done. But we’re not finished with the transformer block.

Next comes the feed-forward network (FFN): a simple two-layer neural network applied to each position independently. Where attention lets positions communicate, the FFN lets each position process the information it’s gathered.

# Store for next notebook
multi_head_data = {
    'X': X,
    'tokens': tokens,
    'multi_head_output': multi_head_output,
    'W_O': W_O,
    'concat_output': concat_output
}
print("Multi-head attention complete. Ready for feed-forward network.")
Multi-head attention complete. Ready for feed-forward network.