The Final Step: Actually Learning¶
We’ve traced the entire forward pass through our tiny transformer. We’ve computed the loss. We’ve backpropagated gradients through every layer, all the way back to the embeddings.
Now we have ~3,500 gradients—one for each parameter in the model. Each gradient tells us: “If you increase this parameter slightly, the loss will change by approximately this much.”
The question is: what do we do with these gradients?
The simplest answer would be plain stochastic gradient descent (SGD):
θ_new = θ - α × g
Subtract the gradient (scaled by a learning rate α), and you’re done.
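As a minimal sketch (illustrative values only, not part of the tiny transformer’s code), the entire SGD update fits in one line:
# Minimal SGD sketch (illustrative): step against the gradient, scaled by the learning rate
def sgd_update(theta, gradient, lr=0.001):
    return theta - lr * gradient

print(f"{sgd_update(0.024634, -0.352893):.6f}")  # 0.024987 — a small step opposite the gradient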
But SGD has problems. Modern transformers use something much more sophisticated: AdamW.
Why SGD Isn’t Enough¶
Plain gradient descent has several failure modes:
1. Same learning rate for all parameters
Some parameters need large updates (they’re far from optimal). Others need tiny updates (they’re already close). SGD uses the same learning rate for everyone, which means you have to choose a rate that’s safe for the most sensitive parameters—and that’s too slow for the rest.
2. Noisy gradients cause oscillation
Gradients computed on mini-batches are noisy estimates of the true gradient. SGD follows each noisy gradient directly, which can cause the optimization to zigzag back and forth instead of heading straight toward the minimum.
3. Sensitive to learning rate
Too high a learning rate and training diverges. Too low and it takes forever. The right learning rate depends on the loss surface, which changes during training.
AdamW addresses all of these problems.
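A tiny illustration of failure mode 3 (a made-up one-parameter loss, not the notebook’s model): gradient descent on the loss θ², with gradient 2θ, blows up when the learning rate is too high and barely moves when it is too low.
# Illustrative only: SGD on f(θ) = θ², gradient 2θ, for three learning rates
for lr in [1.1, 0.1, 0.001]:
    theta = 1.0
    for _ in range(20):
        theta = theta - lr * 2 * theta
    print(f"lr = {lr:<6}: theta after 20 steps = {theta:.6f}")
# lr = 1.1 diverges (|theta| grows), lr = 0.1 converges, lr = 0.001 barely moves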
What AdamW Does¶
AdamW combines three powerful ideas:
1. Adaptive learning rates (from Adam)
Each parameter gets its own effective learning rate, based on the history of its gradients. Parameters with consistently large gradients get smaller learning rates (we’re already moving fast). Parameters with small gradients get larger learning rates (we need to push harder).
2. Momentum (from Adam)
Instead of following each noisy gradient directly, we track an exponential moving average of past gradients. This smooths out the noise and helps the optimizer maintain direction even when individual gradients flip back and forth (see the sketch just after this list).
3. Weight decay (the “W” in AdamW)
Regularization that shrinks weights toward zero, which helps prevent overfitting by penalizing large weights. AdamW applies weight decay directly to the parameters, separate from the gradient update. (Earlier Adam implementations folded weight decay into the gradient as L2 regularization, which interacts poorly with the adaptive learning rates—decoupling it is what the “W” fixes.)
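A minimal sketch of idea 2 (the gradient values here are made up for illustration): an exponential moving average damps the sign flips of a zigzagging gradient while preserving its average direction.
# Illustrative only: an exponential moving average smooths a zigzagging gradient
beta1_demo = 0.9
noisy_grads = [0.5, -0.3, 0.6, -0.2, 0.55, -0.25]  # made-up gradients, mean ≈ +0.15
m_demo = 0.0
for g_demo in noisy_grads:
    m_demo = beta1_demo * m_demo + (1 - beta1_demo) * g_demo  # momentum update
    print(f"g = {g_demo:+.2f}   m = {m_demo:+.4f}")
# The raw gradients flip sign every step; m stays small and positive,
# tracking the average direction instead of zigzagging.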
import math
# AdamW hyperparameters (standard values from the literature)
learning_rate = 0.001 # α (alpha): base learning rate
beta1 = 0.9 # β₁: decay rate for first moment (momentum)
beta2 = 0.999 # β₂: decay rate for second moment (adaptive LR)
epsilon = 1e-8 # ε: small constant for numerical stability
weight_decay = 0.01 # λ: weight decay coefficient
print("AdamW Hyperparameters")
print("="*50)
print(f" learning_rate (α) = {learning_rate}")
print(f" beta1 (β₁) = {beta1}")
print(f" beta2 (β₂) = {beta2}")
print(f" epsilon (ε) = {epsilon}")
print(f" weight_decay (λ) = {weight_decay}")
print()
print("These values are standard—most transformer training")
print("uses exactly these numbers.")AdamW Hyperparameters
==================================================
learning_rate (α) = 0.001
beta1 (β₁) = 0.9
beta2 (β₂) = 0.999
epsilon (ε) = 1e-08
weight_decay (λ) = 0.01
These values are standard—most transformer training
uses exactly these numbers.
The AdamW Algorithm¶
For each parameter θ, AdamW maintains two state variables across training steps:
m — First moment estimate: An exponential moving average of the gradient (this is momentum)
v — Second moment estimate: An exponential moving average of the squared gradient (this enables adaptive learning rates)
At each time step t, given gradient gₜ:
Step 1: Update biased moment estimates
mₜ = β₁ × mₜ₋₁ + (1 - β₁) × gₜ
vₜ = β₂ × vₜ₋₁ + (1 - β₂) × gₜ²
The first moment accumulates gradient direction (momentum). The second moment accumulates gradient magnitude.
Step 2: Bias correction
m̂ₜ = mₜ / (1 - β₁^t)
v̂ₜ = vₜ / (1 - β₂^t)
Since m and v are initialized to zero, they’re biased toward zero in early steps. This correction compensates for that bias.
Step 3: Weight decay
θ ← θ × (1 - α × λ)
Shrink the parameter toward zero by a small amount. This is regularization.
Step 4: Parameter update
θ ← θ - α × m̂ₜ / (√v̂ₜ + ε)
Update the parameter using the bias-corrected moments. The √v̂ₜ in the denominator creates the adaptive learning rate.
def adamw_update(theta, gradient, m, v, t,
                 lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8, wd=0.01):
    """
    Perform one AdamW update step for a single parameter.
    Args:
        theta: Current parameter value
        gradient: Gradient of loss w.r.t. this parameter
        m: First moment estimate (momentum)
        v: Second moment estimate (for adaptive LR)
        t: Time step (starts at 1, increments each update)
        lr: Base learning rate (alpha)
        beta1: First moment decay rate
        beta2: Second moment decay rate
        eps: Small constant for numerical stability
        wd: Weight decay coefficient
    Returns:
        new_theta: Updated parameter value
        new_m: Updated first moment
        new_v: Updated second moment
    """
    # Step 1: Update biased moments
    m_new = beta1 * m + (1 - beta1) * gradient
    v_new = beta2 * v + (1 - beta2) * gradient**2
    # Step 2: Bias correction
    m_hat = m_new / (1 - beta1**t)
    v_hat = v_new / (1 - beta2**t)
    # Step 3: Weight decay (applied before gradient update)
    theta_decayed = theta * (1 - lr * wd)
    # Step 4: Parameter update
    theta_new = theta_decayed - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta_new, m_new, v_new
Worked Example: Updating One Parameter¶
Let’s walk through a complete AdamW update for a single parameter. We’ll update the first element of the <BOS> token’s embedding.
Suppose:
Current value: θ = 0.024634
Gradient: g = -0.352893 (negative means increasing θ would decrease the loss)
This is time step t = 1 (the first update)
Moments initialized to zero: m₀ = 0, v₀ = 0
# Initial state for our example parameter
theta = 0.024634 # Current value of E_token[<BOS>][0]
g = -0.352893 # Gradient (negative = we should increase theta)
m_prev = 0.0 # First moment (initialized to 0)
v_prev = 0.0 # Second moment (initialized to 0)
t = 1 # Time step (first update)
print("Initial State")
print("="*50)
print(f" Parameter θ = {theta:.6f}")
print(f" Gradient g = {g:.6f}")
print(f" First moment m = {m_prev:.6f}")
print(f" Second moment v = {v_prev:.6f}")
print(f" Time step t = {t}")Initial State
==================================================
Parameter θ = 0.024634
Gradient g = -0.352893
First moment m = 0.000000
Second moment v = 0.000000
Time step t = 1
# Step 1: Update biased moments
m_new = beta1 * m_prev + (1 - beta1) * g
v_new = beta2 * v_prev + (1 - beta2) * g**2
print("Step 1: Update Biased Moment Estimates")
print("="*60)
print()
print("First moment (momentum):")
print(f" m₁ = β₁ × m₀ + (1 - β₁) × g")
print(f" = {beta1} × {m_prev} + {1-beta1:.1f} × {g:.6f}")
print(f" = {m_new:.6f}")
print()
print("Second moment (for adaptive LR):")
print(f" v₁ = β₂ × v₀ + (1 - β₂) × g²")
print(f" = {beta2} × {v_prev} + {1-beta2:.3f} × ({g:.6f})²")
print(f" = {beta2} × {v_prev} + {1-beta2:.3f} × {g**2:.6f}")
print(f" = {v_new:.9f}")Step 1: Update Biased Moment Estimates
============================================================
First moment (momentum):
m₁ = β₁ × m₀ + (1 - β₁) × g
= 0.9 × 0.0 + 0.1 × -0.352893
= -0.035289
Second moment (for adaptive LR):
v₁ = β₂ × v₀ + (1 - β₂) × g²
= 0.999 × 0.0 + 0.001 × (-0.352893)²
= 0.999 × 0.0 + 0.001 × 0.124533
= 0.000124533
# Step 2: Bias correction
m_hat = m_new / (1 - beta1**t)
v_hat = v_new / (1 - beta2**t)
print("Step 2: Bias Correction")
print("="*60)
print()
print("Bias-corrected first moment:")
print(f" m̂₁ = m₁ / (1 - β₁^t)")
print(f" = {m_new:.6f} / (1 - {beta1}^{t})")
print(f" = {m_new:.6f} / {1 - beta1**t:.1f}")
print(f" = {m_hat:.6f}")
print()
print("Bias-corrected second moment:")
print(f" v̂₁ = v₁ / (1 - β₂^t)")
print(f" = {v_new:.9f} / (1 - {beta2}^{t})")
print(f" = {v_new:.9f} / {1 - beta2**t:.3f}")
print(f" = {v_hat:.6f}")
print()
print(f"Note: At t=1, m̂ equals the gradient ({m_hat:.6f} ≈ {g:.6f}).")
print(f"This is expected—we haven't accumulated any history yet.")Step 2: Bias Correction
============================================================
Bias-corrected first moment:
m̂₁ = m₁ / (1 - β₁^t)
= -0.035289 / (1 - 0.9^1)
= -0.035289 / 0.1
= -0.352893
Bias-corrected second moment:
v̂₁ = v₁ / (1 - β₂^t)
= 0.000124533 / (1 - 0.999^1)
= 0.000124533 / 0.001
= 0.124533
Note: At t=1, m̂ equals the gradient (-0.352893 ≈ -0.352893).
This is expected—we haven't accumulated any history yet.
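A small aside (illustrative, assuming the gradient stays constant across steps): over several steps the raw moment m only creeps toward the gradient, while the bias-corrected m̂ equals it from the very first step—the correction exactly undoes the zero initialization.
# Illustrative: bias correction over several steps with a constant gradient
g_const = -0.352893
m_demo = 0.0
for t_demo in range(1, 6):
    m_demo = beta1 * m_demo + (1 - beta1) * g_const
    m_hat_demo = m_demo / (1 - beta1**t_demo)
    print(f"t = {t_demo}: m = {m_demo:+.6f}   m̂ = {m_hat_demo:+.6f}")
# m approaches g_const only gradually; m̂ already equals g_const at every step.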
# Step 3: Weight decay
decay_factor = 1 - learning_rate * weight_decay
theta_decayed = theta * decay_factor
print("Step 3: Weight Decay")
print("="*60)
print()
print(f" θ_decayed = θ × (1 - α × λ)")
print(f" = {theta:.6f} × (1 - {learning_rate} × {weight_decay})")
print(f" = {theta:.6f} × {decay_factor:.5f}")
print(f" = {theta_decayed:.9f}")
print()
print(f"Weight shrinkage: {theta - theta_decayed:.9f}")
print(f"(Tiny! Weight decay is a gentle regularization.)") Step 3: Weight Decay
============================================================
θ_decayed = θ × (1 - α × λ)
= 0.024634 × (1 - 0.001 × 0.01)
= 0.024634 × 0.99999
= 0.024633754
Weight shrinkage: 0.000000246
(Tiny! Weight decay is a gentle regularization.)
# Step 4: Parameter update
denominator = math.sqrt(v_hat) + epsilon
adaptive_lr = learning_rate / denominator
update_amount = learning_rate * m_hat / denominator
theta_new = theta_decayed - update_amount
print("Step 4: Parameter Update")
print("="*60)
print()
print("Compute the adaptive learning rate:")
print(f" Effective LR = α / (√v̂ + ε)")
print(f" = {learning_rate} / (√{v_hat:.6f} + {epsilon})")
print(f" = {learning_rate} / ({math.sqrt(v_hat):.6f} + {epsilon})")
print(f" = {learning_rate} / {denominator:.6f}")
print(f" = {adaptive_lr:.6f}")
print()
print(f"This is {adaptive_lr/learning_rate:.2f}× the base learning rate!")
print()
print("Compute the update:")
print(f" Update = α × m̂ / (√v̂ + ε)")
print(f" = {learning_rate} × {m_hat:.6f} / {denominator:.6f}")
print(f" = {update_amount:.6f}")
print()
print("Apply to parameter:")
print(f" θ_new = θ_decayed - update")
print(f" = {theta_decayed:.6f} - ({update_amount:.6f})")
print(f" = {theta_new:.6f}")Step 4: Parameter Update
============================================================
Compute the adaptive learning rate:
Effective LR = α / (√v̂ + ε)
= 0.001 / (√0.124533 + 1e-08)
= 0.001 / (0.352893 + 1e-08)
= 0.001 / 0.352893
= 0.002834
This is 2.83× the base learning rate!
Compute the update:
Update = α × m̂ / (√v̂ + ε)
= 0.001 × -0.352893 / 0.352893
= -0.001000
Apply to parameter:
θ_new = θ_decayed - update
= 0.024634 - (-0.001000)
= 0.025634
# Summary
print("="*60)
print("SUMMARY: One AdamW Update")
print("="*60)
print()
print(f"Parameter: E_token[<BOS>][0]")
print(f" Before: θ = {theta:.6f}")
print(f" After: θ = {theta_new:.6f}")
print(f" Change: {theta_new - theta:+.6f}")
print()
print(f"The gradient was negative ({g:.6f}), meaning:")
print(f" 'Increasing θ would decrease the loss.'")
print(f"So AdamW increased θ. The model learned slightly!")============================================================
SUMMARY: One AdamW Update
============================================================
Parameter: E_token[<BOS>][0]
Before: θ = 0.024634
After: θ = 0.025634
Change: +0.001000
The gradient was negative (-0.352893), meaning:
'Increasing θ would decrease the loss.'
So AdamW increased θ. The model learned slightly!
Why the Adaptive Learning Rate Matters¶
Notice that our effective learning rate was ~2.83× the base learning rate. This happened because √v̂ ≈ 0.353 is less than 1, so dividing α by it amplifies the step (1 / 0.353 ≈ 2.83).
The denominator √v̂ + ε is essentially the RMS (root mean square) of recent gradients for this parameter.
Large gradients → Large √v̂ → Small effective LR → Cautious updates
Small gradients → Small √v̂ → Large effective LR → Aggressive updates
This automatically adapts to each parameter’s gradient scale. Parameters deep in the network often have smaller gradients (due to the chain rule), but Adam compensates by giving them larger effective learning rates.
# Demonstrate how v_hat affects the effective learning rate
print("How Gradient Magnitude Affects Learning Rate")
print("="*60)
print()
print(f"{'|gradient|':>12} {'v̂ (≈g²)':>12} {'Effective LR':>15} {'Multiplier':>12}")
print("-"*60)
test_grads = [0.001, 0.01, 0.1, 0.5, 1.0, 5.0]
for g_test in test_grads:
    v_test = g_test**2  # Simplified: v ≈ g² at t=1
    eff_lr = learning_rate / (math.sqrt(v_test) + epsilon)
    mult = eff_lr / learning_rate
    print(f"{g_test:>12.3f} {v_test:>12.6f} {eff_lr:>15.6f} {mult:>12.2f}×")
print()
print("Small gradients get large multipliers; large gradients get small ones.")
print("This is what makes Adam 'adaptive'.")
How Gradient Magnitude Affects Learning Rate
============================================================
|gradient| v̂ (≈g²) Effective LR Multiplier
------------------------------------------------------------
0.001 0.000001 0.999990 999.99×
0.010 0.000100 0.100000 100.00×
0.100 0.010000 0.010000 10.00×
0.500 0.250000 0.002000 2.00×
1.000 1.000000 0.001000 1.00×
5.000 25.000000 0.000200 0.20×
Small gradients get large multipliers; large gradients get small ones.
This is what makes Adam 'adaptive'.
Updating All Parameters¶
We apply this same AdamW update to every single parameter in the model. Let’s count them:
| Component | Shape | Count |
|---|---|---|
| Token embeddings | [6, 16] | 96 |
| Position embeddings | [5, 16] | 80 |
| Attention W_Q (×2 heads) | [16, 8] × 2 | 256 |
| Attention W_K (×2 heads) | [16, 8] × 2 | 256 |
| Attention W_V (×2 heads) | [16, 8] × 2 | 256 |
| Attention W_O | [16, 16] | 256 |
| FFN W_1 | [64, 16] | 1,024 |
| FFN b_1 | [64] | 64 |
| FFN W_2 | [16, 64] | 1,024 |
| FFN b_2 | [16] | 16 |
| Layer norm γ | [16] | 16 |
| Layer norm β | [16] | 16 |
| LM head W_lm | [6, 16] | 96 |
| Total | | 3,456 |
Each parameter has its own m and v state, which means AdamW needs to store 2× as many values as there are parameters. This is the memory cost of adaptive optimizers. (A small sketch of that per-parameter state follows the parameter count below.)
# Parameter count breakdown
param_counts = {
    "Token embeddings": 6 * 16,
    "Position embeddings": 5 * 16,
    "W_Q (2 heads)": 2 * 16 * 8,
    "W_K (2 heads)": 2 * 16 * 8,
    "W_V (2 heads)": 2 * 16 * 8,
    "W_O": 16 * 16,
    "FFN W_1": 64 * 16,
    "FFN b_1": 64,
    "FFN W_2": 16 * 64,
    "FFN b_2": 16,
    "Layer norm γ": 16,
    "Layer norm β": 16,
    "LM head W_lm": 6 * 16,
}
print("Parameter Count")
print("="*40)
total = 0
for name, count in param_counts.items():
    print(f" {name:25s}: {count:>6,}")
    total += count
print("-"*40)
print(f" {'Total':25s}: {total:>6,}")
print()
print(f"AdamW state (m and v): {2 * total:,} values")
Parameter Count
========================================
Token embeddings : 96
Position embeddings : 80
W_Q (2 heads) : 256
W_K (2 heads) : 256
W_V (2 heads) : 256
W_O : 256
FFN W_1 : 1,024
FFN b_1 : 64
FFN W_2 : 1,024
FFN b_2 : 16
Layer norm γ : 16
Layer norm β : 16
LM head W_lm : 96
----------------------------------------
Total : 3,456
AdamW state (m and v): 6,912 values
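To make the per-parameter state concrete, here is a minimal sketch (hypothetical parameter names and values, not the notebook’s actual model code) that applies adamw_update across a dictionary of parameters, keeping a separate m and v for each:
# Minimal sketch (hypothetical names): one AdamW step over a dict of scalar parameters
params = {"w_a": 0.5, "w_b": -0.2}                       # made-up parameter values
grads = {"w_a": 0.1, "w_b": -0.05}                       # made-up gradients for this step
state = {name: {"m": 0.0, "v": 0.0} for name in params}  # two extra values per parameter
step = 1
for name in params:
    params[name], state[name]["m"], state[name]["v"] = adamw_update(
        params[name], grads[name], state[name]["m"], state[name]["v"], step
    )
print(params)  # each parameter moved by its own adaptive step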
The Complete Training Loop¶
We’ve now traced through one complete training step:
Forward pass (Notebooks 01-07)
Tokenize input → Embeddings → Attention → FFN → Layer Norm → Logits → Loss
Backward pass (Notebooks 08-09)
Compute ∂L/∂logits → Backprop through each layer → Gradients for all parameters
Optimization (This notebook)
Update each parameter using AdamW
In pseudocode:
# Initialize model parameters randomly
# Initialize Adam state (m=0, v=0 for each parameter)
t = 0
for epoch in range(num_epochs):
    for batch in training_data:
        t += 1
        # Forward pass
        predictions = model.forward(batch)
        loss = cross_entropy(predictions, targets)
        # Backward pass
        gradients = model.backward(loss)
        # Optimization
        for param in model.parameters:
            param.value, param.m, param.v = adamw_update(
                param.value, param.gradient, param.m, param.v, t
            )
Repeat this loop millions of times. Each iteration, the loss gets smaller. The model gets better.
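As a self-contained illustration of the loop (a toy one-parameter problem, not the notebook’s transformer; the learning rate is raised and weight decay turned off just so it converges quickly), we can run the adamw_update function from above end to end on the loss (θ - 3)², whose gradient is 2(θ - 3):
# Toy training loop (illustrative): minimize (theta - 3)^2 with repeated AdamW steps
theta_toy, m_toy, v_toy = 0.0, 0.0, 0.0
for step in range(1, 2001):
    grad_toy = 2 * (theta_toy - 3)  # "backward pass" for the toy loss
    theta_toy, m_toy, v_toy = adamw_update(
        theta_toy, grad_toy, m_toy, v_toy, step, lr=0.01, wd=0.0
    )
    if step % 500 == 0:
        print(f"step {step:4d}: theta = {theta_toy:.4f}, loss = {(theta_toy - 3)**2:.6f}")
# theta climbs toward 3 as the loss shrinks: forward, backward, update, repeat.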
Scaling Up¶
Our tiny model has ~3,500 parameters. Real language models have:
| Model | Parameters |
|---|---|
| GPT-2 Small | 124 million |
| GPT-2 Large | 774 million |
| GPT-3 | 175 billion |
| LLaMA 2 70B | 70 billion |
| Claude 3 | Unknown (but large) |
The math is identical. Every one of those billions of parameters goes through the same AdamW update we just computed. The difference is scale—and the engineering required to parallelize the computation across thousands of GPUs.
But the algorithm is exactly what we’ve shown: forward pass, backward pass, AdamW update. Repeat.
What We’ve Accomplished¶
We calculated—by hand, with explicit numbers—a complete training step through a transformer:
Forward pass:
Tokenization and embedding lookup
Q/K/V projections for attention
Scaled dot-product attention with causal masking
Multi-head concatenation and output projection
Feed-forward network with GELU activation
Layer normalization with residual connections
Output projection and softmax
Cross-entropy loss computation
Backward pass:
Gradient of loss with respect to logits
Backpropagation through every layer using the chain rule
Gradients for all ~3,500 parameters
Optimization:
AdamW update with momentum, adaptive learning rates, and weight decay
Nothing was hidden. No magic. Just math.
Closing Thoughts¶
You’ve made it through the entire pipeline.
You’ve seen every matrix multiplication, every activation function, every gradient calculation, every weight update. You understand how a transformer processes text, why attention works, what backpropagation computes, and how optimization nudges parameters toward better predictions.
When someone says “a transformer learns by gradient descent,” you now know exactly what that means—down to the individual floating-point operations.
This is the difference between knowing about something and truly understanding it. You don’t just know that transformers use attention mechanisms—you’ve computed the attention scores yourself. You don’t just know that neural networks use backpropagation—you’ve traced the chain rule through every layer.
The next time you use GPT, Claude, or any language model, remember: under the hood, it’s doing exactly what we did here. Billions of times larger. Trillions of times repeated. But the same fundamental math.
Transformers aren’t magic. They’re just math.
And now you understand the math.
print("Training Step Complete")
print("="*50)
print()
print(" Forward pass: ✓")
print(" Backward pass: ✓")
print(" Optimization: ✓")
print()
print("One training iteration done.")
print("Repeat ~billions of times for a real LLM.")Training Step Complete
==================================================
Forward pass: ✓
Backward pass: ✓
Optimization: ✓
One training iteration done.
Repeat ~billions of times for a real LLM.