Multi-Head Attention
Why multiple heads? A single attention mechanism is powerful, but it can only learn one way of relating tokens. Multi-head attention runs several attention mechanisms in parallel (typically 8 or 16), each called a “head.” This gives the model multiple “perspectives” on the relationships between words.
What Do Different Heads Learn?
Through training, different heads naturally specialize in different types of relationships. Research shows that real models develop heads that focus on:
- Syntactic relationships: One head might track subject-verb agreement
- Semantic relationships: Another head might connect related concepts
- Long-range dependencies: A head might link pronouns to their antecedents
- Local patterns: Another head might attend to adjacent words in phrases
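These specializations only emerge through training, but you can inspect them by plotting each head's attention weights as a heatmap. The sketch below is illustrative and not part of this repository's code: it stands in for the attn_weights tensor of shape (batch, num_heads, seq_len, seq_len) that the implementation further down computes, using random weights purely as a placeholder.

```python
import torch
import matplotlib.pyplot as plt

# Stand-in for the per-head attention weights a trained model would produce,
# shape (batch, num_heads, seq_len, seq_len). Random here for illustration only.
num_heads, seq_len = 8, 6
attn_weights = torch.softmax(torch.randn(1, num_heads, seq_len, seq_len), dim=-1)
tokens = ["the", "cat", "sat", "on", "the", "mat"]

fig, axes = plt.subplots(1, num_heads, figsize=(2.5 * num_heads, 2.5))
for h, ax in enumerate(axes):
    # Rows are query positions, columns are key positions for head h.
    ax.imshow(attn_weights[0, h].numpy(), cmap="viridis")
    ax.set_title(f"head {h}")
    ax.set_xticks(range(seq_len)); ax.set_xticklabels(tokens, rotation=90)
    ax.set_yticks(range(seq_len)); ax.set_yticklabels(tokens)
plt.tight_layout()
plt.show()
```

With weights from a trained model, a “local pattern” head shows a bright band near the diagonal, while a head tracking pronouns lights up the cell linking the pronoun to its antecedent.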
How It Works
We split the d_model dimensions across the heads, so each head works on d_k = d_model / num_heads dimensions. With d_model = 512 and 8 heads, each head operates on 64 dimensions (512 / 8). All heads process the input in parallel; we then concatenate their outputs and apply a final linear transformation (W_o).
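To make the bookkeeping concrete, here is a small sketch (illustrative tensor names and sizes, not taken from the repository's code) of the split-and-merge reshapes in PyTorch:

```python
import torch

batch_size, seq_len, d_model, num_heads = 2, 10, 512, 8
d_k = d_model // num_heads  # 64 dimensions per head

x = torch.randn(batch_size, seq_len, d_model)  # a projected Q (or K, V)

# Split the last dimension into (num_heads, d_k), then move heads forward
# so each head is an independent (seq_len, d_k) slice.
heads = x.view(batch_size, seq_len, num_heads, d_k).transpose(1, 2)
print(heads.shape)  # torch.Size([2, 8, 10, 64])

# Concatenating the heads back recovers the original layout exactly.
merged = heads.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
print(torch.equal(merged, x))  # True
```

No information is lost in the split: the same 512 numbers per token are simply regrouped into 8 chunks of 64 so that each head can attend over its own subspace.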
Implementation
```python
def forward(self, x, mask=None):
    batch_size, seq_len, d_model = x.shape

    # 1. Project input to Q, K, V
    Q = self.W_q(x)  # (batch, seq_len, d_model)
    K = self.W_k(x)
    V = self.W_v(x)

    # 2. Split into multiple heads
    # Reshape to (batch, seq_len, num_heads, d_k), then transpose to
    # (batch, num_heads, seq_len, d_k) so each head attends independently
    Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
    K = K.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
    V = V.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    # 3. Apply attention to each head (in parallel!)
    output, attn_weights = self.attention(Q, K, V, mask)

    # 4. Concatenate heads back together
    output = output.transpose(1, 2).contiguous()
    output = output.view(batch_size, seq_len, d_model)

    # 5. Final linear projection
    output = self.W_o(output)

    return output
```
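The forward method above relies on attributes defined elsewhere in the class: the four projection layers W_q, W_k, W_v, W_o, plus num_heads, d_k, and a scaled dot-product attention helper. Since only forward is shown here, the following is a sketch of what that surrounding module plausibly looks like; consult src/transformer/attention.py for the actual definitions.

```python
import math
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    """Illustrative sketch of the module that forward() above belongs to."""

    def __init__(self, d_model=512, num_heads=8):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must divide evenly across heads"
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        # One linear layer per projection; each maps d_model -> d_model.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def attention(self, Q, K, V, mask=None):
        # Scaled dot-product attention, applied to every head at once.
        # Q, K, V: (batch, num_heads, seq_len, d_k)
        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        weights = torch.softmax(scores, dim=-1)
        return weights @ V, weights

    # forward(self, x, mask=None) as shown above.
```

With the forward method above added to this class, mha = MultiHeadAttention(d_model=512, num_heads=8) applied to a (2, 10, 512) input returns a (2, 10, 512) output.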
Full Code
See the full implementation: src/transformer/attention.py