Scaled Dot-Product Attention

The core innovation of transformers: Attention is the mechanism that allows each word to “look at” and gather information from all the other words in the sentence. This ability to mix information across the whole sequence in a single step is what makes transformers so powerful.

A concrete example: Consider the sentence “The animal didn’t cross the street because it was too tired.” What does “it” refer to? A human knows “it” refers to “the animal” (not “the street”). Attention allows the model to learn this—when processing “it”, the model can attend strongly to “animal” and incorporate that information.

The mechanism uses three components for each token, each derived from the input embedding through a learned linear transformation (a minimal projection sketch follows the list below):

  • Query (Q): “What am I looking for?” - represents what the current token wants to know
  • Key (K): “What do I advertise?” - represents what each token offers to be matched against queries
  • Value (V): “What do I actually carry?” - the content that gets passed along once a token is attended to
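
To make the “learned linear transformations” concrete, here is a minimal sketch of how the three projections might be built in PyTorch. The names q_proj, k_proj, v_proj and the dimension d_model are illustrative assumptions, not taken from attention.py:

import torch
import torch.nn as nn

d_model = 512                               # assumed embedding size
batch, seq_len = 2, 10

# One learned linear map per role; in a real module these are typically submodules.
q_proj = nn.Linear(d_model, d_model)
k_proj = nn.Linear(d_model, d_model)
v_proj = nn.Linear(d_model, d_model)

x = torch.randn(batch, seq_len, d_model)    # input token embeddings
query, key, value = q_proj(x), k_proj(x), v_proj(x)
print(query.shape, key.shape, value.shape)  # each: torch.Size([2, 10, 512])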

The process: For each token, we compare its Query against every Key (using dot products) to compute attention scores: how much attention should this token pay to each of the other tokens? We normalize these scores with softmax to get probabilities (weights that sum to 1), then use those weights to take a weighted average of all the Values.
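
The same three steps written out as a tiny PyTorch sketch on random tensors (the shapes and variable names here are illustrative assumptions, not the repository's code):

import torch

seq_len, d_k = 4, 8
Q = torch.randn(seq_len, d_k)
K = torch.randn(seq_len, d_k)
V = torch.randn(seq_len, d_k)

scores = Q @ K.T / d_k ** 0.5               # one score per (query token, key token) pair
weights = torch.softmax(scores, dim=-1)     # each row is a probability distribution
print(weights.sum(dim=-1))                  # tensor([1., 1., 1., 1.])

output = weights @ V                        # weighted average of the Values, shape (4, 8)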

The formula is elegant:

Attention(Q, K, V) = softmax(Q·Kᵀ / √d_k) · V

The division by √d_k is a scaling factor. In high dimensions the raw dot products can become very large, pushing the softmax into saturated regions where its gradients are near zero and training becomes difficult.
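
A quick numeric check of that claim (illustrative only): for random vectors with unit-variance components, the raw dot product has a standard deviation of about √d_k, so for d_k = 512 the logits swing by roughly ±22 and the softmax saturates; dividing by √d_k brings the spread back to about 1.

import torch

d_k = 512
q = torch.randn(10_000, d_k)
k = torch.randn(10_000, d_k)

raw = (q * k).sum(dim=-1)                   # unscaled dot products
print(raw.std())                            # ≈ sqrt(512) ≈ 22.6
print((raw / d_k ** 0.5).std())             # ≈ 1.0 after scaling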

Attention mechanism flow diagram

def forward(self, query, key, value, mask=None):
    # Get dimension for scaling
    d_k = query.size(-1)
    # Compute attention scores: Q·Kᵀ / √d_k
    scores = torch.matmul(query, key.transpose(-2, -1))
    scores = scores / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
    # Apply causal mask if provided (prevent looking at future positions)
    if mask is not None:
        scores = scores.masked_fill(mask == 1, float('-inf'))
    # Apply softmax to get attention weights (probabilities)
    attention_weights = torch.softmax(scores, dim=-1)
    # Apply attention weights to values
    output = torch.matmul(attention_weights, value)
    return output, attention_weights
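
One possible way to exercise this forward pass with a causal mask. The class name ScaledDotProductAttention, its import path, and its no-argument constructor are assumptions for illustration; the actual names in src/transformer/attention.py may differ:

import torch
from transformer.attention import ScaledDotProductAttention  # hypothetical name/path

attn = ScaledDotProductAttention()

batch, seq_len, d_k = 2, 5, 64
q = torch.randn(batch, seq_len, d_k)
k = torch.randn(batch, seq_len, d_k)
v = torch.randn(batch, seq_len, d_k)

# Causal mask: 1 marks future positions, which forward() fills with -inf.
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)

out, weights = attn(q, k, v, mask=mask)
print(out.shape)                            # torch.Size([2, 5, 64])
print(weights[0, 0])                        # first row: all weight on position 0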

See the full implementation: src/transformer/attention.py