The Chain Rule: Our Only Tool¶
In the last notebook, we computed $\frac{\partial L}{\partial \text{logits}}$ — how the loss depends on the raw prediction scores. But we can’t modify logits directly. They’re computed from hidden states, which are computed from the FFN, which depends on attention, which depends on embeddings.
To train the model, we need gradients for all the learnable parameters:
Token embeddings ($E_{\text{token}}$)
Position embeddings ($E_{\text{pos}}$)
Q, K, V projection matrices ($W_Q$, $W_K$, $W_V$ for each head)
Output projection ($W_O$)
FFN weights ($W_1$, $b_1$, $W_2$, $b_2$)
Layer norm parameters ($\gamma$, $\beta$)
Language modeling head ($W_{\text{lm}}$)
The chain rule is how we get there. If the loss $L$ depends on $y$, and $y$ depends on $x$:
$$\frac{\partial L}{\partial x} = \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial x}$$
We already have $\frac{\partial L}{\partial y}$ from the previous layer. We just need to compute the local derivative $\frac{\partial y}{\partial x}$ and multiply.
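Here is a tiny, standalone numeric example of that composition (toy functions, not part of our model):
# A tiny illustration of the chain rule: L = y**2 and y = 3*x + 1,
# so dL/dx = dL/dy * dy/dx = (2*y) * 3.
x = 2.0
y = 3 * x + 1              # forward: y = 7
loss = y ** 2              # forward: L = 49
dL_dy = 2 * y              # local derivative of L with respect to y -> 14
dy_dx = 3                  # local derivative of y with respect to x
dL_dx = dL_dy * dy_dx      # chain rule -> 42
# Check against a finite difference
eps = 1e-6
loss_plus = (3 * (x + eps) + 1) ** 2
print(dL_dx, (loss_plus - loss) / eps)   # both approximately 42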
This notebook works backward through every layer, computing gradients as we go.
import random
import math
random.seed(42)
VOCAB_SIZE = 6
D_MODEL = 16
D_FF = 64
MAX_SEQ_LEN = 5
NUM_HEADS = 2
D_K = D_MODEL // NUM_HEADS # 8
TOKEN_NAMES = ["<PAD>", "<BOS>", "<EOS>", "I", "like", "transformers"]
# Helper functions
def random_vector(size, scale=0.1):
return [random.gauss(0, scale) for _ in range(size)]
def random_matrix(rows, cols, scale=0.1):
return [[random.gauss(0, scale) for _ in range(cols)] for _ in range(rows)]
def zeros_matrix(rows, cols):
return [[0.0] * cols for _ in range(rows)]
def zeros_vector(size):
return [0.0] * size
def format_vector(vec, decimals=4):
return "[" + ", ".join([f"{v:7.{decimals}f}" for v in vec]) + "]"Step 1: Output Layer (LM Head) Gradients¶
The output layer computes:
$$\text{logits} = h \, W_{\text{lm}}^T$$
Where:
$h$ has shape [seq_len, d_model] = [4, 16] (hidden states going into prediction)
$W_{\text{lm}}$ has shape [vocab_size, d_model] = [6, 16]
$\text{logits}$ has shape [seq_len, vocab_size] = [4, 6]
We have $\frac{\partial L}{\partial \text{logits}}$ from the previous notebook. We need:
$\frac{\partial L}{\partial W_{\text{lm}}}$ — to update the weights
$\frac{\partial L}{\partial h}$ — to continue backpropagating
Gradient for a linear layer:
For $y = x W^T$ (a general linear layer):
$$\frac{\partial L}{\partial W} = \left(\frac{\partial L}{\partial y}\right)^T x \qquad\qquad \frac{\partial L}{\partial x} = \frac{\partial L}{\partial y}\, W$$
These formulas come from matrix calculus. The key insight is that in a linear layer, input and output are connected through the weight matrix in a way that’s “symmetric” for forward and backward passes.
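These formulas are easy to sanity-check numerically. Below is a small standalone sketch (toy shapes and a made-up loss, not this notebook's model) comparing the analytic dL/dW against a central finite difference; it uses its own random.Random instance so the notebook's global seed and printed values are unaffected:
# Standalone check of the linear-layer gradient formula dL/dW = (dL/dy)^T x
import random
rng = random.Random(0)   # separate RNG; does not disturb the notebook's seed
n_pos, d_in, d_out = 2, 3, 4
x_toy = [[rng.gauss(0, 1) for _ in range(d_in)] for _ in range(n_pos)]
W_toy = [[rng.gauss(0, 1) for _ in range(d_in)] for _ in range(d_out)]
g_toy = [[rng.gauss(0, 1) for _ in range(d_out)] for _ in range(n_pos)]  # stand-in for dL/dy
def toy_loss(W_):
    # y = x @ W^T and L = sum_{p,i} g[p][i] * y[p][i], so dL/dy = g by construction
    y = [[sum(x_toy[p][j] * W_[i][j] for j in range(d_in)) for i in range(d_out)]
         for p in range(n_pos)]
    return sum(g_toy[p][i] * y[p][i] for p in range(n_pos) for i in range(d_out))
# Analytic gradient for one entry: dL/dW[1][2] = sum_p dL/dy[p][1] * x[p][2]
analytic = sum(g_toy[p][1] * x_toy[p][2] for p in range(n_pos))
# Central finite difference on the same entry
eps = 1e-6
W_plus  = [row[:] for row in W_toy]; W_plus[1][2]  += eps
W_minus = [row[:] for row in W_toy]; W_minus[1][2] -= eps
numeric = (toy_loss(W_plus) - toy_loss(W_minus)) / (2 * eps)
print(f"analytic dL/dW[1][2] = {analytic:.6f}, finite difference = {numeric:.6f}")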
# We have 4 positions that make predictions (positions 0-3 predict tokens 1-4)
seq_len = 4
# Loss gradients from previous notebook
dL_dlogits = [
[0.1785, 0.2007, 0.1759, -0.8746, 0.1563, 0.1632], # position 0: should predict I
[0.1836, 0.1969, 0.1805, 0.1233, -0.8500, 0.1657], # position 1: should predict like
[0.1795, 0.2050, 0.1782, 0.1207, 0.1437, -0.8272], # position 2: should predict transformers
[0.1855, 0.2017, -0.8229, 0.1271, 0.1391, 0.1695], # position 3: should predict <EOS>
]
# Simulated hidden states (would come from forward pass)
h = [random_vector(D_MODEL) for _ in range(seq_len)]
# Simulated W_lm (would be initialized at model creation)
W_lm = random_matrix(VOCAB_SIZE, D_MODEL)
print(f"Shapes:")
print(f" dL_dlogits: [{seq_len}, {VOCAB_SIZE}]")
print(f" h: [{seq_len}, {D_MODEL}]")
print(f" W_lm: [{VOCAB_SIZE}, {D_MODEL}]")Shapes:
dL_dlogits: [4, 6]
h: [4, 16]
W_lm: [6, 16]
# Compute gradient for W_lm
# dL_dW_lm[i][j] = sum over positions of dL_dlogits[pos][i] * h[pos][j]
dL_dW_lm = zeros_matrix(VOCAB_SIZE, D_MODEL)
for pos in range(seq_len):
for i in range(VOCAB_SIZE): # vocabulary index
for j in range(D_MODEL): # embedding dimension
dL_dW_lm[i][j] += dL_dlogits[pos][i] * h[pos][j]
print("Gradient for W_lm (dL/dW_lm)")
print(f"Shape: [{VOCAB_SIZE}, {D_MODEL}]")
print()
print("First row (gradient for <PAD> token's projection):")
print(f" {format_vector(dL_dW_lm[0])}")Gradient for W_lm (dL/dW_lm)
Shape: [6, 16]
First row (gradient for <PAD> token's projection):
[-0.0567, -0.0032, -0.0264, 0.0356, 0.0148, -0.0320, 0.0002, -0.0341, 0.0249, 0.0188, 0.0225, 0.0555, 0.0640, -0.0139, -0.0175, -0.0472]
# Compute gradient for hidden states (to continue backprop)
# dL_dh[pos][j] = sum over vocab of dL_dlogits[pos][i] * W_lm[i][j]
dL_dh = zeros_matrix(seq_len, D_MODEL)
for pos in range(seq_len):
for j in range(D_MODEL): # embedding dimension
for i in range(VOCAB_SIZE): # vocabulary index
dL_dh[pos][j] += dL_dlogits[pos][i] * W_lm[i][j]
print("Gradient for hidden states (dL/dh)")
print(f"Shape: [{seq_len}, {D_MODEL}]")
print()
print("Position 0 gradient:")
print(f" {format_vector(dL_dh[0])}")Gradient for hidden states (dL/dh)
Shape: [4, 16]
Position 0 gradient:
[ 0.0903, -0.0471, -0.1804, 0.0674, 0.0800, -0.1086, 0.1034, 0.0616, 0.0108, 0.1447, 0.0161, 0.1209, -0.1277, 0.0034, -0.1746, 0.0298]
Step 2: Layer Normalization Gradients¶
Layer norm is more complex because normalizing one element affects the mean and variance, which affects all elements. The formula is:
$$\text{LN}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$
Where $\mu$ and $\sigma^2$ are computed from $x$ itself.
Gradients for parameters:
$$\frac{\partial L}{\partial \gamma} = \sum_{\text{pos}} \frac{\partial L}{\partial y} \odot \hat{x} \qquad\qquad \frac{\partial L}{\partial \beta} = \sum_{\text{pos}} \frac{\partial L}{\partial y}$$
Where $\hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}}$ is the normalized input.
Gradient for input (to continue backprop):
This involves the Jacobian of layer norm, which is a bit involved because changing one input element affects $\mu$ and $\sigma^2$. The full derivation is:
$$\frac{\partial L}{\partial x_i} = \frac{1}{\sqrt{\sigma^2 + \epsilon}} \left( \frac{\partial L}{\partial \hat{x}_i} - \frac{1}{d} \sum_j \frac{\partial L}{\partial \hat{x}_j} - \frac{\hat{x}_i}{d} \sum_j \frac{\partial L}{\partial \hat{x}_j}\, \hat{x}_j \right), \qquad \frac{\partial L}{\partial \hat{x}_j} = \frac{\partial L}{\partial y_j}\, \gamma_j$$
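As a standalone sketch (a toy 4-dimensional vector, not wired into the simulated values below), the input gradient can be computed directly from this formula:
import math
def layer_norm_backward(x, gamma, dL_dy, eps=1e-5):
    """Input gradient of layer norm for one d-dimensional vector (formula above)."""
    d = len(x)
    mean = sum(x) / d
    var = sum((v - mean) ** 2 for v in x) / d
    inv_std = 1.0 / math.sqrt(var + eps)
    x_hat = [(v - mean) * inv_std for v in x]
    # Gradient with respect to the normalized values: dL/dx_hat_j = dL/dy_j * gamma_j
    dL_dxhat = [dL_dy[j] * gamma[j] for j in range(d)]
    sum_d  = sum(dL_dxhat)                                   # sum_j dL/dx_hat_j
    sum_dx = sum(dL_dxhat[j] * x_hat[j] for j in range(d))   # sum_j dL/dx_hat_j * x_hat_j
    return [inv_std * (dL_dxhat[j] - sum_d / d - x_hat[j] * sum_dx / d) for j in range(d)]
# Toy usage on a made-up vector
dx = layer_norm_backward([0.2, -0.5, 1.0, 0.3], gamma=[1.0] * 4, dL_dy=[0.1, -0.2, 0.05, 0.0])
print([round(v, 4) for v in dx])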
# Simulated normalized values (would come from forward pass)
x_norm = [random_vector(D_MODEL) for _ in range(seq_len)] # x_hat = (x - mean) / std
# Initial gamma = 1, beta = 0 (standard initialization)
gamma = [1.0] * D_MODEL
beta = [0.0] * D_MODEL
# Compute gradients for gamma and beta
dL_dgamma = zeros_vector(D_MODEL)
dL_dbeta = zeros_vector(D_MODEL)
for pos in range(seq_len):
for j in range(D_MODEL):
dL_dgamma[j] += dL_dh[pos][j] * x_norm[pos][j]
dL_dbeta[j] += dL_dh[pos][j]
print("Layer Norm Gradients")
print("="*50)
print()
print("Gradient for gamma (first 8 of 16 values):")
print(f" {format_vector(dL_dgamma[:8])}")
print()
print("Gradient for beta (first 8 of 16 values):")
print(f" {format_vector(dL_dbeta[:8])}")Layer Norm Gradients
==================================================
Gradient for gamma (first 8 of 16 values):
[ 0.0225, -0.0240, -0.0296, 0.0103, 0.0170, -0.0249, 0.0179, -0.0028]
Gradient for beta (first 8 of 16 values):
[-0.0287, -0.0846, -0.2322, -0.0055, -0.1733, 0.0561, 0.0873, 0.1953]
Step 3: Feed-Forward Network Gradients¶
The FFN computes:
$$\text{FFN}(x) = \text{GELU}(x W_1^T + b_1)\, W_2^T + b_2$$
Breaking this into steps:
$h_1 = x W_1^T + b_1$ — Linear projection (expand to 64 dims)
$h_2 = \text{GELU}(h_1)$ — Activation function
$y = h_2 W_2^T + b_2$ — Linear projection (back to 16 dims)
We backprop through each in reverse order.
GELU derivative:
GELU is defined as $\text{GELU}(x) = x\, \Phi(x)$, where $\Phi$ is the standard Gaussian CDF.
Its derivative is:
$$\text{GELU}'(x) = \Phi(x) + x\, \phi(x)$$
Where $\phi$ is the Gaussian PDF.
def gelu_derivative(x):
"""
Derivative of GELU activation function.
GELU(x) = x * Phi(x) where Phi is standard Gaussian CDF.
GELU'(x) = Phi(x) + x * phi(x) where phi is Gaussian PDF.
"""
# Gaussian CDF approximation (same as in forward GELU)
cdf = 0.5 * (1 + math.tanh(math.sqrt(2/math.pi) * (x + 0.044715 * x**3)))
# Gaussian PDF
pdf = math.exp(-x**2 / 2) / math.sqrt(2 * math.pi)
return cdf + x * pdf
# Show GELU derivative behavior
print("GELU Derivative Values")
print("="*40)
print()
print(f"{'x':>8} {'GELU(x)':>12} {'GELU\'(x)':>12}")
print("-"*40)
def gelu(x):
return 0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi) * (x + 0.044715 * x**3)))
for x in [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]:
print(f"{x:>8.1f} {gelu(x):>12.4f} {gelu_derivative(x):>12.4f}")GELU Derivative Values
========================================
x GELU(x) GELU'(x)
----------------------------------------
-2.0 -0.0454 -0.0853
-1.0 -0.1588 -0.0832
-0.5 -0.1543 0.1325
0.0 0.0000 0.5000
0.5 0.3457 0.8675
1.0 0.8412 1.0832
2.0 1.9546 1.0853
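As a quick sanity check, we can compare gelu_derivative against a central finite difference of gelu (both defined above). Small discrepancies are expected because the forward pass uses the tanh approximation of the Gaussian CDF while the derivative uses the exact PDF:
# Compare the analytic GELU derivative to a numerical derivative of gelu()
eps = 1e-5
for x in [-1.5, -0.3, 0.0, 0.7, 2.0]:
    numeric = (gelu(x + eps) - gelu(x - eps)) / (2 * eps)
    print(f"x = {x:5.2f}   analytic = {gelu_derivative(x):8.4f}   numeric = {numeric:8.4f}")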
# Simulated values from forward pass
h1 = [random_vector(D_FF) for _ in range(seq_len)] # Before GELU [4, 64]
h2 = [[gelu(val) for val in row] for row in h1] # After GELU [4, 64]
x_ffn = [random_vector(D_MODEL) for _ in range(seq_len)] # Input to FFN [4, 16]
# Weights (would be initialized at model creation)
W1 = random_matrix(D_FF, D_MODEL) # [64, 16]
W2 = random_matrix(D_MODEL, D_FF) # [16, 64]
b1 = random_vector(D_FF) # [64]
b2 = random_vector(D_MODEL) # [16]
# Gradient flowing in (after residual connection handling)
dL_dy = dL_dh # [4, 16]
print("FFN Shapes:")
print(f" W1: [{D_FF}, {D_MODEL}] (expand)")
print(f" W2: [{D_MODEL}, {D_FF}] (project)")
print(f" h1, h2: [{seq_len}, {D_FF}]")FFN Shapes:
W1: [64, 16] (expand)
W2: [16, 64] (project)
h1, h2: [4, 64]
# Step 3a: Gradient for W2 and b2 (second linear layer)
# y = h2 @ W2^T + b2
dL_dW2 = zeros_matrix(D_MODEL, D_FF)
dL_db2 = zeros_vector(D_MODEL)
for pos in range(seq_len):
for i in range(D_MODEL):
dL_db2[i] += dL_dy[pos][i]
for j in range(D_FF):
dL_dW2[i][j] += dL_dy[pos][i] * h2[pos][j]
print("FFN Second Layer Gradients")
print(f" dL_dW2 shape: [{D_MODEL}, {D_FF}]")
print(f" dL_db2 shape: [{D_MODEL}]")
print()
print(f" dL_db2 (first 8): {format_vector(dL_db2[:8])}")FFN Second Layer Gradients
dL_dW2 shape: [16, 64]
dL_db2 shape: [16]
dL_db2 (first 8): [-0.0287, -0.0846, -0.2322, -0.0055, -0.1733, 0.0561, 0.0873, 0.1953]
# Step 3b: Gradient flowing to h2 (for backprop through GELU)
# dL_dh2[pos][j] = sum_i dL_dy[pos][i] * W2[i][j]
dL_dh2 = zeros_matrix(seq_len, D_FF)
for pos in range(seq_len):
for j in range(D_FF):
for i in range(D_MODEL):
dL_dh2[pos][j] += dL_dy[pos][i] * W2[i][j]
# Step 3c: Backprop through GELU (element-wise)
# dL_dh1[pos][j] = dL_dh2[pos][j] * GELU'(h1[pos][j])
dL_dh1 = zeros_matrix(seq_len, D_FF)
for pos in range(seq_len):
for j in range(D_FF):
dL_dh1[pos][j] = dL_dh2[pos][j] * gelu_derivative(h1[pos][j])
print("Backprop through GELU")
print(f" dL_dh2 shape: [{seq_len}, {D_FF}]")
print(f" dL_dh1 shape: [{seq_len}, {D_FF}]")Backprop through GELU
dL_dh2 shape: [4, 64]
dL_dh1 shape: [4, 64]
# Step 3d: Gradient for W1 and b1 (first linear layer)
# h1 = x @ W1^T + b1
dL_dW1 = zeros_matrix(D_FF, D_MODEL)
dL_db1 = zeros_vector(D_FF)
for pos in range(seq_len):
for i in range(D_FF):
dL_db1[i] += dL_dh1[pos][i]
for j in range(D_MODEL):
dL_dW1[i][j] += dL_dh1[pos][i] * x_ffn[pos][j]
print("FFN First Layer Gradients")
print(f" dL_dW1 shape: [{D_FF}, {D_MODEL}]")
print(f" dL_db1 shape: [{D_FF}]")
print()
print(f" dL_db1 (first 8 of 64): {format_vector(dL_db1[:8])}")FFN First Layer Gradients
dL_dW1 shape: [64, 16]
dL_db1 shape: [64]
dL_db1 (first 8 of 64): [-0.0364, 0.0281, 0.0095, -0.0106, -0.0364, 0.0285, -0.0004, -0.0165]
Step 4: Attention Gradients¶
This is the most complex part. The attention mechanism involves:
Q, K, V projections: $Q = X W_Q$, $K = X W_K$, $V = X W_V$
Attention scores: $S = \frac{Q K^T}{\sqrt{d_k}}$
Softmax: $A = \text{softmax}(S)$
Weighted values: $Z = A V$
Output projection: $\text{output} = \text{concat}(Z_{\text{head 0}}, Z_{\text{head 1}})\, W_O^T$
We need to backprop through all of these. Let’s start with the output projection.
# Simulated concatenated attention output
concat_attn = [random_vector(D_MODEL) for _ in range(seq_len)]
W_O = random_matrix(D_MODEL, D_MODEL)
# Gradient for W_O
# multi_head = concat @ W_O^T
# dL_dW_O[i][j] = sum_pos dL_dmh[pos][i] * concat[pos][j]
dL_dW_O = zeros_matrix(D_MODEL, D_MODEL)
for pos in range(seq_len):
for i in range(D_MODEL):
for j in range(D_MODEL):
dL_dW_O[i][j] += dL_dh[pos][i] * concat_attn[pos][j]
print("Output Projection Gradient")
print(f" dL_dW_O shape: [{D_MODEL}, {D_MODEL}]")
print(f" First row: {format_vector(dL_dW_O[0][:8])}...")Output Projection Gradient
dL_dW_O shape: [16, 16]
First row: [-0.0054, -0.0005, -0.0072, -0.0334, 0.0020, 0.0262, -0.0032, 0.0370]...
# Gradient flowing to concatenated attention
dL_dconcat = zeros_matrix(seq_len, D_MODEL)
for pos in range(seq_len):
for j in range(D_MODEL):
for i in range(D_MODEL):
dL_dconcat[pos][j] += dL_dh[pos][i] * W_O[i][j]
print("Gradient for concatenated attention output")
print(f" Shape: [{seq_len}, {D_MODEL}]")Gradient for concatenated attention output
Shape: [4, 16]
Attention Weight Gradients¶
For the Q, K, V projection matrices, we need to backprop through:
The weighted sum: $Z = A V$
The softmax operation
The scaled dot product: $S = \frac{Q K^T}{\sqrt{d_k}}$
The linear projections
This is where backpropagation gets intricate. Each head has its own $W_Q$, $W_K$, $W_V$, and we need gradients for all of them.
# Simulated values for one attention head
# In a full implementation, we'd do this for each head
X = [random_vector(D_MODEL) for _ in range(seq_len)] # Input
W_Q_head0 = random_matrix(D_MODEL, D_K) # [16, 8]
W_K_head0 = random_matrix(D_MODEL, D_K) # [16, 8]
W_V_head0 = random_matrix(D_MODEL, D_K) # [16, 8]
# Q, K, V for head 0 (computed in forward pass)
Q = [[sum(X[i][k] * W_Q_head0[k][j] for k in range(D_MODEL)) for j in range(D_K)] for i in range(seq_len)]
K = [[sum(X[i][k] * W_K_head0[k][j] for k in range(D_MODEL)) for j in range(D_K)] for i in range(seq_len)]
V = [[sum(X[i][k] * W_V_head0[k][j] for k in range(D_MODEL)) for j in range(D_K)] for i in range(seq_len)]
print("Attention Head 0 Matrices")
print(f" W_Q: [{D_MODEL}, {D_K}]")
print(f" W_K: [{D_MODEL}, {D_K}]")
print(f" W_V: [{D_MODEL}, {D_K}]")
print(f" Q, K, V: [{seq_len}, {D_K}]")Attention Head 0 Matrices
W_Q: [16, 8]
W_K: [16, 8]
W_V: [16, 8]
Q, K, V: [4, 8]
# Simplified gradient computation for W_Q (head 0)
# In full backprop, we'd compute this through the attention mechanism
# For demonstration, assume we have dL_dQ (gradient flowing into Q)
dL_dQ = [random_vector(D_K) for _ in range(seq_len)] # [4, 8]
# Gradient for W_Q: dL_dW_Q[i][j] = sum_pos dL_dQ[pos][j] * X[pos][i]
dL_dW_Q = zeros_matrix(D_MODEL, D_K)
for pos in range(seq_len):
for i in range(D_MODEL):
for j in range(D_K):
dL_dW_Q[i][j] += dL_dQ[pos][j] * X[pos][i]
print("Gradient for W_Q (head 0)")
print(f" Shape: [{D_MODEL}, {D_K}]")
print(f" First row: {format_vector(dL_dW_Q[0])}")Gradient for W_Q (head 0)
Shape: [16, 8]
First row: [-0.0017, -0.0116, -0.0035, 0.0058, 0.0073, -0.0046, -0.0093, -0.0092]
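For completeness, here is a hedged sketch of the full path from a head's output back to Q, K, and V, reusing the Q, K, V computed above. It assumes a hypothetical upstream gradient dL_dZ on the head output Z = A·V and a causal mask as in the forward pass; in a full implementation, the dL_dQ_head it produces is what would replace the random dL_dQ used in the demonstration above.
import random
rng = random.Random(1)   # separate RNG so the notebook's reproducible values are untouched
# Hypothetical upstream gradient on this head's output Z = A @ V (shape [4, 8])
dL_dZ = [[rng.gauss(0, 0.1) for _ in range(D_K)] for _ in range(seq_len)]
# Recompute the attention weights A for head 0 (causal mask as in the forward pass)
scale = 1.0 / math.sqrt(D_K)
S = [[sum(Q[i][d] * K[j][d] for d in range(D_K)) * scale if j <= i else float("-inf")
      for j in range(seq_len)] for i in range(seq_len)]
A = []
for row in S:
    m = max(row)
    exps = [math.exp(s - m) for s in row]
    total = sum(exps)
    A.append([e / total for e in exps])
# Z = A @ V  =>  dL/dA = dL_dZ @ V^T   and   dL/dV = A^T @ dL_dZ
dL_dA = [[sum(dL_dZ[i][d] * V[j][d] for d in range(D_K)) for j in range(seq_len)]
         for i in range(seq_len)]
dL_dV_head = [[sum(A[i][j] * dL_dZ[i][d] for i in range(seq_len)) for d in range(D_K)]
              for j in range(seq_len)]
# Softmax backward, row by row: dL/dS_ij = A_ij * (dL/dA_ij - sum_k dL/dA_ik * A_ik)
dL_dS = []
for i in range(seq_len):
    dot = sum(dL_dA[i][k] * A[i][k] for k in range(seq_len))
    dL_dS.append([A[i][j] * (dL_dA[i][j] - dot) for j in range(seq_len)])
# S = Q @ K^T / sqrt(d_k)  =>  dL/dQ = dL_dS @ K * scale   and   dL/dK = dL_dS^T @ Q * scale
dL_dQ_head = [[sum(dL_dS[i][j] * K[j][d] for j in range(seq_len)) * scale for d in range(D_K)]
              for i in range(seq_len)]
dL_dK_head = [[sum(dL_dS[i][j] * Q[i][d] for i in range(seq_len)) * scale for d in range(D_K)]
              for j in range(seq_len)]
print("Attention backward sketch (head 0):")
print(f"  dL_dQ, dL_dK, dL_dV shapes: [{seq_len}, {D_K}] each")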
Step 5: Embedding Gradients¶
Finally, we compute gradients for the embedding matrices.
Token embeddings ($E_{\text{token}}$):
The embedding lookup is just indexing: $x_{\text{pos}} = E_{\text{token}}[\text{token\_id}_{\text{pos}}]$. So the gradient only flows to the rows that were actually used.
Position embeddings ($E_{\text{pos}}$):
Same idea—gradient accumulates for each position that was used.
# Tokens used in our sequence
tokens_used = [1, 3, 4, 5, 2] # <BOS>, I, like, transformers, <EOS>
full_seq_len = 5 # Including the last position
# Gradient flowing into embeddings (would come from backprop through attention)
dL_dX = [random_vector(D_MODEL) for _ in range(full_seq_len)]
# Gradient for token embeddings
dL_dE_token = zeros_matrix(VOCAB_SIZE, D_MODEL)
for pos, token_id in enumerate(tokens_used):
for j in range(D_MODEL):
dL_dE_token[token_id][j] += dL_dX[pos][j]
print("Token Embedding Gradients")
print("="*50)
print()
for i, name in enumerate(TOKEN_NAMES):
grad_norm = math.sqrt(sum(g**2 for g in dL_dE_token[i]))
if grad_norm > 0:
print(f" {name:12s}: ||gradient|| = {grad_norm:.4f}")
else:
print(f" {name:12s}: no gradient (token not used)")Token Embedding Gradients
==================================================
<PAD> : no gradient (token not used)
<BOS> : ||gradient|| = 0.4340
<EOS> : ||gradient|| = 0.3268
I : ||gradient|| = 0.2963
like : ||gradient|| = 0.3597
transformers: ||gradient|| = 0.3583
# Gradient for position embeddings
dL_dE_pos = zeros_matrix(MAX_SEQ_LEN, D_MODEL)
for pos in range(full_seq_len):
for j in range(D_MODEL):
dL_dE_pos[pos][j] = dL_dX[pos][j]
print("Position Embedding Gradients")
print("="*50)
print()
for pos in range(full_seq_len):
grad_norm = math.sqrt(sum(g**2 for g in dL_dE_pos[pos]))
print(f" Position {pos}: ||gradient|| = {grad_norm:.4f}")Position Embedding Gradients
==================================================
Position 0: ||gradient|| = 0.4340
Position 1: ||gradient|| = 0.2963
Position 2: ||gradient|| = 0.3597
Position 3: ||gradient|| = 0.3583
Position 4: ||gradient|| = 0.3268
Summary: All Gradients Computed¶
We’ve traced the chain rule backward through the entire network:
| Layer | Parameters | Gradient Shape | Purpose |
|---|---|---|---|
| LM Head | $W_{\text{lm}}$ | [6, 16] | Predict next token |
| Layer Norm | $\gamma$, $\beta$ | [16], [16] | Normalize activations |
| FFN | $W_2$, $b_2$ | [16, 64], [16] | Project back |
| FFN | $W_1$, $b_1$ | [64, 16], [64] | Expand to hidden |
| Attention | $W_O$ | [16, 16] | Output projection |
| Attention | $W_Q$ (×2 heads) | [16, 8] | Query projection |
| Attention | $W_K$ (×2 heads) | [16, 8] | Key projection |
| Attention | $W_V$ (×2 heads) | [16, 8] | Value projection |
| Embeddings | $E_{\text{token}}$ | [6, 16] | Token vectors |
| Embeddings | $E_{\text{pos}}$ | [5, 16] | Position vectors |
Total: ~3,500 parameters, each with its own gradient.
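A quick tally backs up that number, using the constants defined at the top of the notebook (and counting only the single layer norm covered in Step 2):
# Parameter count for this toy model
lm_head    = VOCAB_SIZE * D_MODEL                                 # 96
layer_norm = 2 * D_MODEL                                          # gamma + beta = 32
ffn        = D_FF * D_MODEL + D_FF + D_MODEL * D_FF + D_MODEL     # 2,128
attention  = D_MODEL * D_MODEL + NUM_HEADS * 3 * D_MODEL * D_K    # W_O + per-head Q/K/V = 1,024
embeddings = VOCAB_SIZE * D_MODEL + MAX_SEQ_LEN * D_MODEL         # 96 + 80 = 176
print("Total parameters:", lm_head + layer_norm + ffn + attention + embeddings)  # 3,456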
The Key Insight: Local Computation, Global Effect¶
Backpropagation is beautiful because each layer only needs to know:
What it computed during the forward pass
The gradient flowing in from the layer above
It doesn’t need to know about the loss function, the other layers, or anything else. Each layer computes its local derivatives and passes the gradient backward.
Yet when we’re done, every parameter has a gradient that tells us exactly how it contributed to the final loss—even parameters that are 10 layers removed from the output.
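A minimal sketch of that contract (an illustrative structure, not this notebook's implementation): a layer caches its forward input, and its backward method needs only the incoming gradient to produce both its parameter gradient and the gradient to pass on.
# Each layer: forward caches its input, backward sees only the incoming gradient.
class Linear:
    def __init__(self, W):             # W has shape [out_dim, in_dim]
        self.W = W
        self.x = None                   # cached input from the forward pass
    def forward(self, x):               # y[i] = sum_j W[i][j] * x[j]
        self.x = x
        return [sum(self.W[i][j] * x[j] for j in range(len(x))) for i in range(len(self.W))]
    def backward(self, dL_dy):
        # Local gradients only: needs self.x (forward cache) and dL_dy (from the layer above)
        self.dL_dW = [[dL_dy[i] * self.x[j] for j in range(len(self.x))] for i in range(len(self.W))]
        return [sum(dL_dy[i] * self.W[i][j] for i in range(len(self.W))) for j in range(len(self.x))]
# Toy usage: the layer never sees the loss function, only the incoming gradient.
layer = Linear([[0.1, 0.2], [0.3, -0.1]])
y_out = layer.forward([1.0, 2.0])
dL_dx_out = layer.backward([0.5, -0.5])
print(y_out, dL_dx_out)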
What’s Next: The Optimizer¶
We have gradients for all ~3,500 parameters. The gradient tells us which direction reduces the loss.
The simplest approach would be gradient descent:
$$\theta \leftarrow \theta - \eta\, \frac{\partial L}{\partial \theta}$$
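As a minimal illustration (reusing W_lm and dL_dW_lm from Step 1; the update rule we will actually use comes in the next notebook), one plain gradient-descent step would look like this:
# One plain gradient-descent step on the LM head weights (illustrative only)
learning_rate = 0.01
W_lm_new = [[W_lm[i][j] - learning_rate * dL_dW_lm[i][j] for j in range(D_MODEL)]
            for i in range(VOCAB_SIZE)]
print("Largest weight change:", max(abs(learning_rate * g) for row in dL_dW_lm for g in row))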
But modern transformers use AdamW, which is much more sophisticated:
Adaptive learning rates: Each parameter gets its own learning rate based on gradient history
Momentum: Smooth out noisy gradients by averaging over time
Weight decay: Regularize by shrinking weights toward zero
The next notebook implements AdamW and completes our training loop.
print("Backpropagation Complete")
print("="*50)
print()
print("Gradients computed for:")
print(" - LM head (96 params)")
print(" - Layer norm (32 params)")
print(" - FFN (2,128 params)")
print(" - Attention (1,024 params)")
print(" - Embeddings (~176 params)")
print()
print("Ready for optimization step.")Backpropagation Complete
==================================================
Gradients computed for:
- LM head (96 params)
- Layer norm (32 params)
- FFN (2,128 params)
- Attention (1,024 params)
- Embeddings (~176 params)
Ready for optimization step.