Multi-Head Attention

Why multiple heads? A single attention mechanism is powerful, but it can only learn one way of relating tokens. Multi-head attention runs several attention mechanisms in parallel (typically 8 or 16), each called a “head.” This gives the model multiple “perspectives” to understand relationships between words.

Through training, different heads naturally specialize in different types of relationships. Studies of trained Transformer models have found heads that focus on, for example:

  • Syntactic relationships: One head might track subject-verb agreement
  • Semantic relationships: Another head might connect related concepts
  • Long-range dependencies: A head might link pronouns to their antecedents
  • Local patterns: Another head might attend to adjacent words in phrases

We split the d_model dimensions across heads. With d_model=512 and 8 heads, each head operates on 64 dimensions (512/8). All heads process the input in parallel, then we concatenate their outputs and apply a final linear transformation.
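In the notation of the original Transformer paper ("Attention Is All You Need", Vaswani et al., 2017), h heads of dimension d_k = d_model / h are computed independently and concatenated:

    \mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\, W^O
    \quad \text{where} \quad \mathrm{head}_i = \mathrm{Attention}(Q W_i^Q,\; K W_i^K,\; V W_i^V)

    \text{with } W_i^Q, W_i^K, W_i^V \in \mathbb{R}^{d_{\text{model}} \times d_k}, \quad W^O \in \mathbb{R}^{h d_k \times d_{\text{model}}}, \quad d_k = d_{\text{model}} / h

The implementation below fuses the per-head matrices W_i^Q, W_i^K, W_i^V into single d_model × d_model projections (W_q, W_k, W_v) and splits the result into heads afterwards; this is mathematically equivalent and lets all heads be computed with one matrix multiply.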

Multi-head attention architecture

def forward(self, x, mask=None):
    batch_size, seq_len, d_model = x.shape

    # 1. Project input to Q, K, V
    Q = self.W_q(x)  # (batch, seq_len, d_model)
    K = self.W_k(x)
    V = self.W_v(x)

    # 2. Split into multiple heads
    #    Reshape to (batch, seq_len, num_heads, d_k), then transpose
    #    to (batch, num_heads, seq_len, d_k) so each head is a separate slice
    Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
    K = K.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
    V = V.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    # 3. Apply attention to each head (in parallel!)
    output, attn_weights = self.attention(Q, K, V, mask)

    # 4. Concatenate heads back together
    output = output.transpose(1, 2).contiguous()
    output = output.view(batch_size, seq_len, d_model)

    # 5. Final linear projection
    output = self.W_o(output)
    return output

See the full implementation: src/transformer/attention.py
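
If you want to run the forward pass above outside the repository, here is a minimal self-contained sketch of the surrounding module. The class name MultiHeadAttention, the constructor signature, and the attention helper are assumptions chosen to match the attribute names used above (W_q, W_k, W_v, W_o, num_heads, d_k); the actual definitions live in src/transformer/attention.py.

import math
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    # Hypothetical surrounding module; see src/transformer/attention.py for the real one.
    def __init__(self, d_model=512, num_heads=8):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must divide evenly across heads"
        self.num_heads = num_heads
        self.d_k = d_model // num_heads        # per-head dimension, e.g. 512 / 8 = 64
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def attention(self, Q, K, V, mask=None):
        # Scaled dot-product attention, computed for all heads at once
        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        attn_weights = torch.softmax(scores, dim=-1)
        return attn_weights @ V, attn_weights

    def forward(self, x, mask=None):
        # Same five steps as the forward pass shown above, written compactly
        batch_size, seq_len, d_model = x.shape
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        output, _ = self.attention(Q, K, V, mask)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        return self.W_o(output)

# Sanity check: shapes are preserved end to end
mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512)    # (batch, seq_len, d_model)
out = mha(x)
print(out.shape)               # torch.Size([2, 10, 512])

Running the sanity check prints torch.Size([2, 10, 512]): the module returns a tensor with the same shape as its input, which is what lets multi-head attention be stacked layer after layer.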