Scaled Dot-Product Attention
Attention is the core innovation of transformers: the mechanism that allows each word to “look at” every other word in the sentence and gather information from it. This is what makes transformers so powerful.
A concrete example: Consider the sentence “The animal didn’t cross the street because it was too tired.” What does “it” refer to? A human knows “it” refers to “the animal” (not “the street”). Attention allows the model to learn this—when processing “it”, the model can attend strongly to “animal” and incorporate that information.
How Attention Works
The mechanism uses three components for each token, derived from the input embeddings through learned linear transformations:
- Query (Q): “What am I looking for?” - represents what the current token wants to know
- Key (K): “What do I offer?” - a label each token exposes so queries can be matched against it
- Value (V): “What do I actually contain?” - the content that gets passed along once a match is made
The process: For each token, we compare its Query against all Keys (using dot products) to compute attention scores—how much should we pay attention to each other token? We normalize these scores with softmax to get probabilities (weights that sum to 1), then use these weights to take a weighted average of all Values.
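To make this concrete, here is a minimal sketch of the whole process in PyTorch. The dimensions and the projection names (`W_q`, `W_k`, `W_v`) are illustrative, not taken from the full implementation below:

```python
import torch
import torch.nn as nn

d_model = 64   # embedding size (illustrative)
seq_len = 10   # number of tokens

# Learned linear transformations that turn each token embedding
# into its query, key, and value vectors.
W_q = nn.Linear(d_model, d_model)
W_k = nn.Linear(d_model, d_model)
W_v = nn.Linear(d_model, d_model)

x = torch.randn(seq_len, d_model)        # token embeddings
Q, K, V = W_q(x), W_k(x), W_v(x)

# Raw scores: one row per query token, one column per key token.
scores = Q @ K.T / (d_model ** 0.5)
weights = torch.softmax(scores, dim=-1)  # each row sums to 1
output = weights @ V                     # weighted average of values
print(weights.sum(dim=-1))               # tensor of ones
```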
The formula is elegant:
Attention(Q, K, V) = softmax(Q·Kᵀ / √d_k) · V
The division by √d_k is a scaling factor: in high dimensions, dot products grow large (their variance scales with d_k), which pushes the softmax into saturated, near-one-hot regions where gradients are close to zero and training stalls.
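A quick numerical illustration of why the scaling matters (the setup here is assumed for demonstration, not from the implementation):

```python
import torch

torch.manual_seed(0)
d_k = 512
q = torch.randn(d_k)
keys = torch.randn(8, d_k)

raw = keys @ q             # unscaled dot products, std ≈ √d_k ≈ 22.6
scaled = raw / d_k ** 0.5  # scaled scores, std ≈ 1

print(torch.softmax(raw, dim=-1))     # nearly one-hot: softmax is saturated
print(torch.softmax(scaled, dim=-1))  # smoother distribution, useful gradients
```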
Implementation
This is the forward pass of the attention module (an excerpt; see the full module in the file linked below):

```python
def forward(self, query, key, value, mask=None):
    # Get dimension for scaling
    d_k = query.size(-1)

    # Compute attention scores: Q·Kᵀ / √d_k
    scores = torch.matmul(query, key.transpose(-2, -1))
    scores = scores / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))

    # Apply causal mask if provided (prevent looking at future);
    # positions where mask == 1 are blocked by setting their scores to -inf
    if mask is not None:
        scores = scores.masked_fill(mask == 1, float('-inf'))

    # Apply softmax to get attention weights (probabilities summing to 1)
    attention_weights = torch.softmax(scores, dim=-1)

    # Apply attention weights to values
    output = torch.matmul(attention_weights, value)

    return output, attention_weights
```
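To see the mask convention in action, here is a standalone sketch (with `self` dropped and shapes chosen for illustration) that replays the same steps using a causal mask built from `torch.triu`:

```python
import torch

seq_len, d_k = 4, 8
q = torch.randn(seq_len, d_k)
k = torch.randn(seq_len, d_k)
v = torch.randn(seq_len, d_k)

# Causal mask: 1 above the diagonal marks future positions to block,
# matching the mask == 1 convention in forward() above.
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)

scores = torch.matmul(q, k.transpose(-2, -1)) / d_k ** 0.5
scores = scores.masked_fill(mask == 1, float('-inf'))
weights = torch.softmax(scores, dim=-1)

# Lower-triangular result: each token attends only to itself and the past.
print(weights)
```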
Full Code
See the full implementation: src/transformer/attention.py