The Backward Pass Begins¶
The forward pass is done. We fed “I like transformers” through our tiny transformer and got a loss of about 1.9, close to the ln(6) ≈ 1.79 we'd expect from uniform guessing over our 6-token vocabulary. The model has no idea what it’s doing.
Now we need to fix that. We need to figure out how to adjust the ~2,600 parameters so the model does better next time.
The tool for this is backpropagation—the algorithm that computes how much each parameter contributed to the final loss. Once we know that, we can nudge each parameter in the direction that reduces the loss.
The big picture:
For every parameter $\theta$ in the model, we want to compute:

$$\frac{\partial L}{\partial \theta}$$

This is the gradient: it tells us how much the loss would change if we increased $\theta$ by a tiny amount. A positive gradient means increasing $\theta$ increases the loss (bad). A negative gradient means increasing $\theta$ decreases the loss (good).
The strategy:
We work backward from the loss. First we compute how the loss depends on the logits. Then how the logits depend on the hidden states. Then how the hidden states depend on the FFN. And so on, all the way back to the embeddings.
This is the chain rule in action:

$$\frac{\partial L}{\partial \theta} = \frac{\partial L}{\partial \text{logits}} \cdot \frac{\partial \text{logits}}{\partial \text{hidden}} \cdot \frac{\partial \text{hidden}}{\partial \theta}$$
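To make the chain rule concrete, here's a tiny standalone illustration (a made-up composed function, completely separate from the transformer): the analytic chain-rule gradient matches a finite-difference estimate.

# Toy chain-rule check (illustration only; f, g, and x are made up)
# L = f(g(x)) with f(u) = u**2 and g(x) = 3*x + 1
# Chain rule: dL/dx = f'(g(x)) * g'(x) = 2*(3*x + 1) * 3
def g(x):
    return 3 * x + 1

def f(u):
    return u ** 2

x = 0.5
analytic = 2 * (3 * x + 1) * 3                         # chain rule
eps = 1e-6
numeric = (f(g(x + eps)) - f(g(x - eps))) / (2 * eps)  # central finite difference
print(f"Chain rule:        {analytic:.6f}")
print(f"Finite difference: {numeric:.6f}")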
Cross-Entropy + Softmax: A Beautiful Gradient¶
Let’s start at the very end of the computation graph. Our loss is:

$$L = -\log P(\text{target})$$

Where $P(\text{target})$ comes from softmax:

$$P(i) = \frac{e^{\text{logit}_i}}{\sum_j e^{\text{logit}_j}}$$

We need $\frac{\partial L}{\partial \text{logit}_i}$ for each vocabulary token $i$.
The derivation:
This requires some calculus, but the result is remarkably clean. Let’s work through it.
For the loss $L = -\log P(t)$ where $t$ is the target:

$$\frac{\partial L}{\partial \text{logit}_i} = \frac{\partial L}{\partial P(t)} \cdot \frac{\partial P(t)}{\partial \text{logit}_i}$$

The first term is straightforward:

$$\frac{\partial L}{\partial P(t)} = -\frac{1}{P(t)}$$

The second term (the softmax derivative) depends on whether $i = t$:

If $i = t$:

$$\frac{\partial P(t)}{\partial \text{logit}_i} = P(t)\,(1 - P(t))$$

If $i \neq t$:

$$\frac{\partial P(t)}{\partial \text{logit}_i} = -P(t)\,P(i)$$

Combining these:

If $i = t$:

$$\frac{\partial L}{\partial \text{logit}_i} = -\frac{1}{P(t)} \cdot P(t)\,(1 - P(t)) = P(t) - 1$$

If $i \neq t$:

$$\frac{\partial L}{\partial \text{logit}_i} = -\frac{1}{P(t)} \cdot \bigl(-P(t)\,P(i)\bigr) = P(i)$$

The final formula:

$$\frac{\partial L}{\partial \text{logit}_i} = P(i) - \mathbb{1}[i = t]$$

Where $\mathbb{1}[i = t]$ is the indicator function: 1 if $i$ equals the target, 0 otherwise.
This is one of the most elegant results in machine learning. The gradient is just: predicted probability minus target probability.
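If you want to sanity-check the formula numerically, here's a small sketch that re-implements softmax and cross-entropy on made-up logits (the softmax and cross_entropy helpers and the check_logits values are invented for this check, not taken from the forward-pass notebook) and compares $P(i) - \mathbb{1}[i = t]$ against a central finite difference.

import math

def softmax(logits):
    # Numerically stable softmax
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(logits, target):
    # L = -log P(target)
    return -math.log(softmax(logits)[target])

check_logits = [-0.5, 0.2, -0.3, -0.8, -0.2, 0.1]  # made-up logits for a 6-token vocab
check_target = 3

p = softmax(check_logits)
analytic = [p[i] - (1.0 if i == check_target else 0.0) for i in range(len(check_logits))]

eps = 1e-6
for i in range(len(check_logits)):
    up = check_logits.copy()
    dn = check_logits.copy()
    up[i] += eps
    dn[i] -= eps
    numeric = (cross_entropy(up, check_target) - cross_entropy(dn, check_target)) / (2 * eps)
    print(f"logit[{i}]: analytic = {analytic[i]:>8.4f}, finite diff = {numeric:>8.4f}")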
import random
import math
random.seed(42)
VOCAB_SIZE = 6
D_MODEL = 16
D_FF = 64
MAX_SEQ_LEN = 5
NUM_HEADS = 2
D_K = D_MODEL // NUM_HEADS
TOKEN_NAMES = ["<PAD>", "<BOS>", "<EOS>", "I", "like", "transformers"]

# For this notebook, we'll use pre-computed probabilities from the forward pass
# These are approximately what a randomly initialized model would produce
probs = [
[0.1785, 0.2007, 0.1759, 0.1254, 0.1563, 0.1632], # position 0: <BOS>
[0.1836, 0.1969, 0.1805, 0.1233, 0.1500, 0.1657], # position 1: I
[0.1795, 0.2050, 0.1782, 0.1207, 0.1437, 0.1728], # position 2: like
[0.1855, 0.2017, 0.1771, 0.1271, 0.1391, 0.1695], # position 3: transformers
]
# Target tokens (what the model should predict at each position)
targets = [3, 4, 5, 2] # I, like, transformers, <EOS>
# Current tokens (input at each position)
tokens = [1, 3, 4, 5, 2] # <BOS>, I, like, transformers, <EOS>
print("Forward pass predictions:")
print("="*60)
print()
for i, t in enumerate(targets):
    prob_correct = probs[i][t]
    print(f"Position {i}: {TOKEN_NAMES[tokens[i]]:12s} → should predict {TOKEN_NAMES[t]:12s} (P = {prob_correct:.4f})")

Forward pass predictions:
============================================================
Position 0: <BOS> → should predict I (P = 0.1254)
Position 1: I → should predict like (P = 0.1500)
Position 2: like → should predict transformers (P = 0.1728)
Position 3: transformers → should predict <EOS> (P = 0.1771)
Computing the Gradient¶
Now let’s compute $\frac{\partial L}{\partial \text{logit}_i}$ for each position and each vocabulary token.
The formula is simple:
For the target token: gradient = $P(\text{target}) - 1$ (always negative)
For all other tokens: gradient = $P(i)$ (always positive)
def compute_loss_gradient(probs, target):
    """
    Compute gradient of cross-entropy loss w.r.t. logits.

    dL/dlogit[i] = P(i) - 1   if i == target
    dL/dlogit[i] = P(i)       otherwise

    Args:
        probs: Probability distribution from softmax [vocab_size]
        target: Index of the correct token

    Returns:
        Gradient vector [vocab_size]
    """
    grad = probs.copy()
    grad[target] -= 1.0  # Subtract 1 from the target position
    return grad
# Compute gradients for all positions
dL_dlogits = []
for i in range(len(targets)):
    grad = compute_loss_gradient(probs[i], targets[i])
    dL_dlogits.append(grad)
print("Loss Gradients w.r.t. Logits (dL/dlogit)")
print("="*75)
print()
print(f"{'Position':<12} {'<PAD>':>9} {'<BOS>':>9} {'<EOS>':>9} {'I':>9} {'like':>9} {'trans':>9}")
print("-"*75)
for i, grad in enumerate(dL_dlogits):
    target = targets[i]
    row = f"{TOKEN_NAMES[tokens[i]]:<12}"
    for j, g in enumerate(grad):
        marker = "*" if j == target else " "
        row += f" {g:>8.4f}{marker}"
    print(row)
print()
print("* marks the target token (negative gradient)")

Loss Gradients w.r.t. Logits (dL/dlogit)
===========================================================================
Position <PAD> <BOS> <EOS> I like trans
---------------------------------------------------------------------------
<BOS> 0.1785 0.2007 0.1759 -0.8746* 0.1563 0.1632
I 0.1836 0.1969 0.1805 0.1233 -0.8500* 0.1657
like 0.1795 0.2050 0.1782 0.1207 0.1437 -0.8272*
transformers 0.1855 0.2017 -0.8229* 0.1271 0.1391 0.1695
* marks the target token (negative gradient)
Understanding the Gradients¶
Let’s break down position 0 in detail. The model sees <BOS> and should predict I.
print("Detailed: Position 0 (<BOS> → should predict I)")
print("="*70)
print()
print("Step 1: Current probabilities from softmax")
print("-"*50)
for j, name in enumerate(TOKEN_NAMES):
    marker = "← TARGET" if j == targets[0] else ""
    print(f" P({name:12s}) = {probs[0][j]:.4f} {marker}")
print(f" Sum = {sum(probs[0]):.4f}")
print()
print("Step 2: Compute gradients using formula: dL/dlogit = P(i) - 1[i==target]")
print("-"*70)
for j, name in enumerate(TOKEN_NAMES):
    is_target = 1 if j == targets[0] else 0
    grad = probs[0][j] - is_target
    print(f" dL/dlogit[{name:12s}] = {probs[0][j]:.4f} - {is_target} = {grad:>8.4f}")

Detailed: Position 0 (<BOS> → should predict I)
======================================================================
Step 1: Current probabilities from softmax
--------------------------------------------------
P(<PAD> ) = 0.1785
P(<BOS> ) = 0.2007
P(<EOS> ) = 0.1759
P(I ) = 0.1254 ← TARGET
P(like ) = 0.1563
P(transformers) = 0.1632
Sum = 1.0000
Step 2: Compute gradients using formula: dL/dlogit = P(i) - 1[i==target]
----------------------------------------------------------------------
dL/dlogit[<PAD> ] = 0.1785 - 0 = 0.1785
dL/dlogit[<BOS> ] = 0.2007 - 0 = 0.2007
dL/dlogit[<EOS> ] = 0.1759 - 0 = 0.1759
dL/dlogit[I ] = 0.1254 - 1 = -0.8746
dL/dlogit[like ] = 0.1563 - 0 = 0.1563
dL/dlogit[transformers] = 0.1632 - 0 = 0.1632
Why These Gradients Make Sense¶
Key insight: Gradients point in the direction of increasing loss.
When we do gradient descent, we subtract the gradient (times a learning rate):

$$\text{logit}_i \leftarrow \text{logit}_i - \eta \cdot \frac{\partial L}{\partial \text{logit}_i}$$

So:
For the correct token (gradient = $P(\text{target}) - 1$):
We subtract a negative number
That’s the same as adding $1 - P(\text{target})$
The logit increases
Higher logit → higher probability ✓
For incorrect tokens (gradient = $P(i)$):
We subtract a positive number
The logit decreases
Lower logit → lower probability ✓
Gradient descent naturally pushes the correct answer’s probability up and everything else down.
# Demonstrate what happens with a gradient update
learning_rate = 0.1
# Pretend these are the logits (we'll make up some values)
logits_before = [-0.5, 0.2, -0.3, -0.8, -0.2, 0.1] # Position 0
print("Effect of one gradient descent step (learning rate = 0.1)")
print("="*70)
print()
print(f"{'Token':<12} {'Logit Before':>12} {'Gradient':>10} {'Logit After':>12} {'Change':>10}")
print("-"*70)
logits_after = []
for j, name in enumerate(TOKEN_NAMES):
    grad = dL_dlogits[0][j]
    new_logit = logits_before[j] - learning_rate * grad
    logits_after.append(new_logit)
    change = new_logit - logits_before[j]
    marker = "← TARGET" if j == targets[0] else ""
    print(f"{name:<12} {logits_before[j]:>12.4f} {grad:>10.4f} {new_logit:>12.4f} {change:>+10.4f} {marker}")
print()
print("The target token's logit increased; all others decreased.")

Effect of one gradient descent step (learning rate = 0.1)
======================================================================
Token Logit Before Gradient Logit After Change
----------------------------------------------------------------------
<PAD> -0.5000 0.1785 -0.5179 -0.0179
<BOS> 0.2000 0.2007 0.1799 -0.0201
<EOS> -0.3000 0.1759 -0.3176 -0.0176
I -0.8000 -0.8746 -0.7125 +0.0875 ← TARGET
like -0.2000 0.1563 -0.2156 -0.0156
transformers 0.1000 0.1632 0.0837 -0.0163
The target token's logit increased; all others decreased.
Mathematical Property: Gradients Sum to Zero¶
Here’s a nice verification we can do. The gradients at each position should sum to zero:

$$\sum_i \frac{\partial L}{\partial \text{logit}_i} = 0$$

This makes sense: softmax is translation-invariant. Adding a constant to all logits doesn’t change the probabilities (it cancels out in the normalization), so the gradient with respect to “shift all logits equally” should be zero. (With our hand-rounded probabilities, one position is off by about $10^{-4}$; that's just rounding error, not a real violation.)
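Here's a quick way to see that invariance with made-up logits (the softmax helper and demo_logits below exist only for this demonstration): shifting every logit by the same constant leaves the probabilities unchanged.

import math

def softmax(logits):
    # Numerically stable softmax
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

demo_logits = [-0.5, 0.2, -0.3, -0.8, -0.2, 0.1]  # made-up logits
shifted = [x + 3.0 for x in demo_logits]          # add the same constant to every logit
print("softmax(logits):      ", [round(p, 4) for p in softmax(demo_logits)])
print("softmax(logits + 3.0):", [round(p, 4) for p in softmax(shifted)])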
print("Verification: Gradients sum to zero at each position")
print("="*50)
print()
for i, grad in enumerate(dL_dlogits):
    grad_sum = sum(grad)
    status = "✓" if abs(grad_sum) < 1e-6 else f"✗ (off by {grad_sum:.2e})"
    print(f"Position {i} ({TOKEN_NAMES[tokens[i]]:12s}): sum = {grad_sum:>12.10f} {status}")

Verification: Gradients sum to zero at each position
==================================================
Position 0 (<BOS> ): sum = -0.0000000000 ✓
Position 1 (I ): sum = 0.0000000000 ✓
Position 2 (like ): sum = -0.0001000000 ✗ (off by -1.00e-04)
Position 3 (transformers): sum = 0.0000000000 ✓
Gradient Magnitude Tells a Story¶
Notice the magnitudes of the gradients:
Target gradient: ~-0.85 (large negative)
Non-target gradients: ~+0.17 (small positive)
The target gradient is much larger in magnitude. Why?
Because $|P(\text{target}) - 1|$ is large when the model is wrong. Our model assigns only ~15% probability to the correct answer, so the gradient is roughly $0.15 - 1 = -0.85$.
As training progresses:
If $P(\text{target}) = 0.9$: gradient becomes $-0.1$ (small)
If $P(\text{target}) = 0.99$: gradient becomes $-0.01$ (tiny)
The gradient naturally gets smaller as the model improves—the learning signal weakens when there’s less to learn. This is a form of adaptive learning.
# Show how gradient magnitude changes with model confidence
print("How gradient magnitude depends on model confidence")
print("="*55)
print()
print(f"{'P(target)':>12} {'Target Gradient':>18} {'Interpretation':<25}")
print("-"*55)
test_probs = [0.01, 0.10, 0.25, 0.50, 0.75, 0.90, 0.99]
for p in test_probs:
    grad = p - 1
    if p < 0.2:
        interp = "Very wrong, strong signal"
    elif p < 0.5:
        interp = "Uncertain, moderate signal"
    elif p < 0.8:
        interp = "Getting it, weaker signal"
    else:
        interp = "Confident, tiny signal"
    print(f"{p:>12.2f} {grad:>18.4f} {interp:<25}")

How gradient magnitude depends on model confidence
=======================================================
P(target) Target Gradient Interpretation
-------------------------------------------------------
0.01 -0.9900 Very wrong, strong signal
0.10 -0.9000 Very wrong, strong signal
0.25 -0.7500 Uncertain, moderate signal
0.50 -0.5000 Getting it, weaker signal
0.75 -0.2500 Getting it, weaker signal
0.90 -0.1000 Confident, tiny signal
0.99 -0.0100 Confident, tiny signal
The Gradient Tensor Shape¶
Let’s be explicit about what we’ve computed. The gradient tensor $\frac{\partial L}{\partial \text{logits}}$ has shape $[4, 6]$:
4 positions (we don’t predict after <EOS>)
6 vocabulary tokens
Each entry tells us how much the loss would change if we increased that particular logit by a tiny amount.
print(f"Gradient tensor shape: [{len(dL_dlogits)}, {len(dL_dlogits[0])}]")
print(f" - 4 positions (0 through 3, predicting tokens 1 through 4)")
print(f" - 6 vocabulary tokens")
print(f" - Total: 24 gradient values")

Gradient tensor shape: [4, 6]
- 4 positions (0 through 3, predicting tokens 1 through 4)
- 6 vocabulary tokens
- Total: 24 gradient values
What’s Next: Backpropagating Further¶
We’ve computed $\frac{\partial L}{\partial \text{logits}}$. This tells us how the loss depends on the raw prediction scores.
But we can’t directly modify the logits; they’re computed from earlier layers:

$$\text{logits} = \text{hidden} \cdot W_{\text{lm\_head}}$$

To train the model, we need gradients for:
$W_{\text{lm\_head}}$ (the language modeling head weights) - so we can update them
The hidden states - so we can backpropagate further into the transformer (a sketch of this matrix-multiply backward step follows below)
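As a preview of that step, here's a minimal pure-Python sketch of the matrix-multiply backward rule, assuming the head is a single matrix multiply logits = hidden @ W_lm_head (the matmul/transpose helpers and the toy shapes and values below are made up for illustration; our actual model uses d_model = 16 and a 6-token vocabulary).

# Matmul backward sketch (toy shapes and made-up values)
# For C = A @ B:  dL/dA = dL/dC @ B^T   and   dL/dB = A^T @ dL/dC
def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

hidden_toy = [[0.1, -0.2, 0.3],   # 2 positions x d_model=3 (toy)
              [0.0, 0.5, -0.1]]
W_toy = [[0.2, -0.1],             # d_model=3 x vocab=2 (toy)
         [0.0, 0.3],
         [-0.2, 0.1]]
dL_dlogits_toy = [[0.4, -0.4],    # 2 positions x vocab=2 (toy)
                  [-0.7, 0.7]]

dL_dW = matmul(transpose(hidden_toy), dL_dlogits_toy)   # [d_model, vocab]
dL_dhidden = matmul(dL_dlogits_toy, transpose(W_toy))   # [positions, d_model]
print("dL/dW_lm_head:", dL_dW)
print("dL/dhidden:   ", dL_dhidden)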
And then we’ll continue backward through:
Layer normalization
Residual connections
Feed-forward network
Multi-head attention
Q/K/V projections
Embeddings
The next notebook continues the backward pass through these layers.
# Store gradients for next notebook
grad_loss_data = {
'dL_dlogits': dL_dlogits,
'probs': probs,
'targets': targets,
'tokens': tokens
}
print("Summary: Loss Gradients")
print("="*50)
print(f"Shape: [4, 6] (4 positions × 6 vocab tokens)")
print(f"Formula: dL/dlogit[i] = P(i) - 1[i==target]")
print(f"Property: Each row sums to 0")
print()
print("Ready for backpropagation through the model.")

Summary: Loss Gradients
==================================================
Shape: [4, 6] (4 positions × 6 vocab tokens)
Formula: dL/dlogit[i] = P(i) - 1[i==target]
Property: Each row sums to 0
Ready for backpropagation through the model.