
AdamW Optimizer

The Final Step: Actually Learning

We’ve traced the entire forward pass through our tiny transformer. We’ve computed the loss. We’ve backpropagated gradients through every layer, all the way back to the embeddings.

Now we have ~3,500 gradients—one for each parameter in the model. Each gradient tells us: “If you increase this parameter slightly, the loss will change by approximately this much.”

The question is: what do we do with these gradients?

The simplest answer would be plain stochastic gradient descent (SGD):

$$\theta_{\text{new}} = \theta_{\text{old}} - \eta \cdot \frac{\partial L}{\partial \theta}$$

Subtract the gradient (scaled by a learning rate $\eta$), and you’re done.
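
In code, plain SGD is a one-liner (a minimal sketch, using the same parameter value and gradient we’ll meet in the worked example below):

# Plain SGD: step against the gradient, scaled by the learning rate eta.
def sgd_update(theta, gradient, eta=0.001):
    return theta - eta * gradient

# A negative gradient nudges the parameter up.
print(sgd_update(0.024634, -0.352893))  # ≈ 0.024987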

But SGD has problems. Modern transformers use something much more sophisticated: AdamW.

Why SGD Isn’t Enough

Plain gradient descent has several failure modes:

1. Same learning rate for all parameters

Some parameters need large updates (they’re far from optimal). Others need tiny updates (they’re already close). SGD uses the same learning rate for everyone, which means you have to choose a rate that’s safe for the most sensitive parameters—and that’s too slow for the rest.

2. Noisy gradients cause oscillation

Gradients computed on mini-batches are noisy estimates of the true gradient. SGD follows each noisy gradient directly, which can cause the optimization to zigzag back and forth instead of heading straight toward the minimum. (A short sketch after this list shows how averaging tames this noise.)

3. Sensitive to learning rate

Too high a learning rate and training diverges. Too low and it takes forever. The right learning rate depends on the loss surface, which changes during training.
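
To see why averaging helps, here is a small synthetic sketch (illustrative numbers, not from our model): gradients that bounce around a true slope of -0.3, and an exponential moving average of them (the same averaging Adam uses below).

import random
random.seed(0)

# Synthetic demo: noisy gradients around a true slope of -0.3.
# The exponential moving average damps the noise but keeps the trend.
beta, ema = 0.9, 0.0
for step in range(1, 6):
    noisy_grad = -0.3 + random.uniform(-0.5, 0.5)
    ema = beta * ema + (1 - beta) * noisy_grad
    print(f"step {step}: raw gradient = {noisy_grad:+.3f}, smoothed = {ema:+.3f}")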

AdamW addresses all of these problems.

What AdamW Does

AdamW combines three powerful ideas:

1. Adaptive learning rates (from Adam)

Each parameter gets its own effective learning rate, based on the history of its gradients. Parameters with consistently large gradients get smaller learning rates (we’re already moving fast). Parameters with small gradients get larger learning rates (we need to push harder).

2. Momentum (from Adam)

Instead of following each noisy gradient directly, we track an exponential moving average of past gradients. This smooths out the noise and helps the optimizer maintain direction even when individual gradients are noisy.

3. Weight decay (the “W” in AdamW)

Regularization that shrinks weights toward zero. This prevents overfitting by penalizing large weights. AdamW applies weight decay directly to parameters, separate from the gradient update. (Earlier practice folded weight decay into Adam’s gradient as an L2 penalty, which interacted poorly with the adaptive learning rates; the sketch after this list contrasts the two.)
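
Why does routing the decay through the gradient go wrong? Here is a simplified, single-step sketch (illustrative numbers; it ignores momentum): an L2 penalty added to the gradient gets divided by the adaptive denominator $\sqrt{\hat{v}}$ (defined in the next section), so weights with small gradients are shrunk far more aggressively. Decoupled decay shrinks every weight at the same rate.

import math

# Simplified single-step comparison (illustrative numbers):
# L2-style decay passes through the adaptive denominator; decoupled doesn't.
lr, wd, theta = 0.001, 0.01, 1.0

for grad_scale in (0.01, 1.0):
    v_hat = grad_scale**2                                       # v̂ ≈ g² at t=1
    l2_shrink = lr * (wd * theta) / (math.sqrt(v_hat) + 1e-8)   # Adam + L2 penalty
    adamw_shrink = lr * wd * theta                              # AdamW, decoupled
    print(f"|g| = {grad_scale:>4}: L2-style shrink = {l2_shrink:.6f}, "
          f"AdamW shrink = {adamw_shrink:.6f}")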

import math

# AdamW hyperparameters (standard values from the literature)
learning_rate = 0.001    # α (alpha): base learning rate
beta1 = 0.9              # β₁: decay rate for first moment (momentum)
beta2 = 0.999            # β₂: decay rate for second moment (adaptive LR)
epsilon = 1e-8           # ε: small constant for numerical stability
weight_decay = 0.01      # λ: weight decay coefficient

print("AdamW Hyperparameters")
print("="*50)
print(f"  learning_rate (α) = {learning_rate}")
print(f"  beta1 (β₁)        = {beta1}")
print(f"  beta2 (β₂)        = {beta2}")
print(f"  epsilon (ε)       = {epsilon}")
print(f"  weight_decay (λ)  = {weight_decay}")
print()
print("These values are standard—most transformer training")
print("uses exactly these numbers.")
AdamW Hyperparameters
==================================================
  learning_rate (α) = 0.001
  beta1 (β₁)        = 0.9
  beta2 (β₂)        = 0.999
  epsilon (ε)       = 1e-08
  weight_decay (λ)  = 0.01

These values are standard—most transformer training
uses exactly these numbers.

The AdamW Algorithm

For each parameter $\theta$, AdamW maintains two state variables across training steps:

  • $m$, the first moment estimate: an exponential moving average of the gradient (this is momentum)

  • $v$, the second moment estimate: an exponential moving average of the squared gradient (this enables adaptive learning rates)

At each time step $t$, given gradient $g_t = \frac{\partial L}{\partial \theta}$:


Step 1: Update biased moment estimates

$$m_t = \beta_1 \cdot m_{t-1} + (1 - \beta_1) \cdot g_t$$
$$v_t = \beta_2 \cdot v_{t-1} + (1 - \beta_2) \cdot g_t^2$$

The first moment $m$ accumulates gradient direction (momentum). The second moment $v$ accumulates gradient magnitude.


Step 2: Bias correction

$$\hat{m}_t = \frac{m_t}{1 - \beta_1^t}$$
$$\hat{v}_t = \frac{v_t}{1 - \beta_2^t}$$

Since $m$ and $v$ are initialized to zero, they’re biased toward zero in early steps. This correction compensates for that bias.
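
A quick numeric check (not part of the worked example): the correction denominators start near zero and approach 1, so the correction matters most in the first steps.

# How the bias-correction denominators grow toward 1 over training steps:
for t_demo in (1, 2, 5, 10, 100, 1000):
    print(f"t = {t_demo:>4}:  1 - β₁^t = {1 - 0.9**t_demo:.4f},  "
          f"1 - β₂^t = {1 - 0.999**t_demo:.6f}")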


Step 3: Weight decay

$$\theta \leftarrow \theta \cdot (1 - \alpha \cdot \lambda)$$

Shrink the parameter toward zero by a small amount. This is regularization.


Step 4: Parameter update

$$\theta \leftarrow \theta - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$

Update the parameter using the bias-corrected moments. The $\sqrt{\hat{v}_t}$ in the denominator creates the adaptive learning rate.

def adamw_update(theta, gradient, m, v, t, 
                 lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8, wd=0.01):
    """
    Perform one AdamW update step for a single parameter.
    
    Args:
        theta: Current parameter value
        gradient: Gradient of loss w.r.t. this parameter
        m: First moment estimate (momentum)
        v: Second moment estimate (for adaptive LR)
        t: Time step (starts at 1, increments each update)
        lr: Base learning rate (alpha)
        beta1: First moment decay rate
        beta2: Second moment decay rate
        eps: Small constant for numerical stability
        wd: Weight decay coefficient
    
    Returns:
        new_theta: Updated parameter value
        new_m: Updated first moment
        new_v: Updated second moment
    """
    # Step 1: Update biased moments
    m_new = beta1 * m + (1 - beta1) * gradient
    v_new = beta2 * v + (1 - beta2) * gradient**2
    
    # Step 2: Bias correction
    m_hat = m_new / (1 - beta1**t)
    v_hat = v_new / (1 - beta2**t)
    
    # Step 3: Weight decay (applied before gradient update)
    theta_decayed = theta * (1 - lr * wd)
    
    # Step 4: Parameter update
    theta_new = theta_decayed - lr * m_hat / (math.sqrt(v_hat) + eps)
    
    return theta_new, m_new, v_new

Worked Example: Updating One Parameter

Let’s walk through a complete AdamW update for a single parameter. We’ll update the first element of the <BOS> token’s embedding.

Suppose:

  • Current value: $\theta = 0.024634$

  • Gradient: $g = -0.352893$ (negative means increasing $\theta$ would decrease loss)

  • This is time step $t = 1$ (first update)

  • Moments initialized to zero: $m_0 = 0$, $v_0 = 0$

# Initial state for our example parameter
theta = 0.024634      # Current value of E_token[<BOS>][0]
g = -0.352893         # Gradient (negative = we should increase theta)
m_prev = 0.0          # First moment (initialized to 0)
v_prev = 0.0          # Second moment (initialized to 0)
t = 1                 # Time step (first update)

print("Initial State")
print("="*50)
print(f"  Parameter θ     = {theta:.6f}")
print(f"  Gradient g      = {g:.6f}")
print(f"  First moment m  = {m_prev:.6f}")
print(f"  Second moment v = {v_prev:.6f}")
print(f"  Time step t     = {t}")
Initial State
==================================================
  Parameter θ     = 0.024634
  Gradient g      = -0.352893
  First moment m  = 0.000000
  Second moment v = 0.000000
  Time step t     = 1
# Step 1: Update biased moments
m_new = beta1 * m_prev + (1 - beta1) * g
v_new = beta2 * v_prev + (1 - beta2) * g**2

print("Step 1: Update Biased Moment Estimates")
print("="*60)
print()
print("First moment (momentum):")
print(f"  m₁ = β₁ × m₀ + (1 - β₁) × g")
print(f"     = {beta1} × {m_prev} + {1-beta1:.1f} × {g:.6f}")
print(f"     = {m_new:.6f}")
print()
print("Second moment (for adaptive LR):")
print(f"  v₁ = β₂ × v₀ + (1 - β₂) × g²")
print(f"     = {beta2} × {v_prev} + {1-beta2:.3f} × ({g:.6f})²")
print(f"     = {beta2} × {v_prev} + {1-beta2:.3f} × {g**2:.6f}")
print(f"     = {v_new:.9f}")
Step 1: Update Biased Moment Estimates
============================================================

First moment (momentum):
  m₁ = β₁ × m₀ + (1 - β₁) × g
     = 0.9 × 0.0 + 0.1 × -0.352893
     = -0.035289

Second moment (for adaptive LR):
  v₁ = β₂ × v₀ + (1 - β₂) × g²
     = 0.999 × 0.0 + 0.001 × (-0.352893)²
     = 0.999 × 0.0 + 0.001 × 0.124533
     = 0.000124533
# Step 2: Bias correction
m_hat = m_new / (1 - beta1**t)
v_hat = v_new / (1 - beta2**t)

print("Step 2: Bias Correction")
print("="*60)
print()
print("Bias-corrected first moment:")
print(f"  m̂₁ = m₁ / (1 - β₁^t)")
print(f"     = {m_new:.6f} / (1 - {beta1}^{t})")
print(f"     = {m_new:.6f} / {1 - beta1**t:.1f}")
print(f"     = {m_hat:.6f}")
print()
print("Bias-corrected second moment:")
print(f"  v̂₁ = v₁ / (1 - β₂^t)")
print(f"     = {v_new:.9f} / (1 - {beta2}^{t})")
print(f"     = {v_new:.9f} / {1 - beta2**t:.3f}")
print(f"     = {v_hat:.6f}")
print()
print(f"Note: At t=1, m̂ equals the gradient ({m_hat:.6f} ≈ {g:.6f}).")
print(f"This is expected—we haven't accumulated any history yet.")
Step 2: Bias Correction
============================================================

Bias-corrected first moment:
  m̂₁ = m₁ / (1 - β₁^t)
     = -0.035289 / (1 - 0.9^1)
     = -0.035289 / 0.1
     = -0.352893

Bias-corrected second moment:
  v̂₁ = v₁ / (1 - β₂^t)
     = 0.000124533 / (1 - 0.999^1)
     = 0.000124533 / 0.001
     = 0.124533

Note: At t=1, m̂ equals the gradient (-0.352893 ≈ -0.352893).
This is expected—we haven't accumulated any history yet.
# Step 3: Weight decay
decay_factor = 1 - learning_rate * weight_decay
theta_decayed = theta * decay_factor

print("Step 3: Weight Decay")
print("="*60)
print()
print(f"  θ_decayed = θ × (1 - α × λ)")
print(f"           = {theta:.6f} × (1 - {learning_rate} × {weight_decay})")
print(f"           = {theta:.6f} × {decay_factor:.5f}")
print(f"           = {theta_decayed:.9f}")
print()
print(f"Weight shrinkage: {theta - theta_decayed:.9f}")
print(f"(Tiny! Weight decay is a gentle regularization.)") 
Step 3: Weight Decay
============================================================

  θ_decayed = θ × (1 - α × λ)
           = 0.024634 × (1 - 0.001 × 0.01)
           = 0.024634 × 0.99999
           = 0.024633754

Weight shrinkage: 0.000000246
(Tiny! Weight decay is a gentle regularization.)
# Step 4: Parameter update
denominator = math.sqrt(v_hat) + epsilon
adaptive_lr = learning_rate / denominator
update_amount = learning_rate * m_hat / denominator
theta_new = theta_decayed - update_amount

print("Step 4: Parameter Update")
print("="*60)
print()
print("Compute the adaptive learning rate:")
print(f"  Effective LR = α / (√v̂ + ε)")
print(f"              = {learning_rate} / (√{v_hat:.6f} + {epsilon})")
print(f"              = {learning_rate} / ({math.sqrt(v_hat):.6f} + {epsilon})")
print(f"              = {learning_rate} / {denominator:.6f}")
print(f"              = {adaptive_lr:.6f}")
print()
print(f"This is {adaptive_lr/learning_rate:.2f}× the base learning rate!")
print()
print("Compute the update:")
print(f"  Update = α × m̂ / (√v̂ + ε)")
print(f"         = {learning_rate} × {m_hat:.6f} / {denominator:.6f}")
print(f"         = {update_amount:.6f}")
print()
print("Apply to parameter:")
print(f"  θ_new = θ_decayed - update")
print(f"        = {theta_decayed:.6f} - ({update_amount:.6f})")
print(f"        = {theta_new:.6f}")
Step 4: Parameter Update
============================================================

Compute the adaptive learning rate:
  Effective LR = α / (√v̂ + ε)
              = 0.001 / (√0.124533 + 1e-08)
              = 0.001 / (0.352893 + 1e-08)
              = 0.001 / 0.352893
              = 0.002834

This is 2.83× the base learning rate!

Compute the update:
  Update = α × m̂ / (√v̂ + ε)
         = 0.001 × -0.352893 / 0.352893
         = -0.001000

Apply to parameter:
  θ_new = θ_decayed - update
        = 0.024634 - (-0.001000)
        = 0.025634
# Summary
print("="*60)
print("SUMMARY: One AdamW Update")
print("="*60)
print()
print(f"Parameter: E_token[<BOS>][0]")
print(f"  Before:  θ = {theta:.6f}")
print(f"  After:   θ = {theta_new:.6f}")
print(f"  Change:      {theta_new - theta:+.6f}")
print()
print(f"The gradient was negative ({g:.6f}), meaning:")
print(f"  'Increasing θ would decrease the loss.'")
print(f"So AdamW increased θ. The model learned slightly!")
============================================================
SUMMARY: One AdamW Update
============================================================

Parameter: E_token[<BOS>][0]
  Before:  θ = 0.024634
  After:   θ = 0.025634
  Change:      +0.001000

The gradient was negative (-0.352893), meaning:
  'Increasing θ would decrease the loss.'
So AdamW increased θ. The model learned slightly!
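
The worked example covered a single update. As an illustrative continuation (a toy quadratic loss, not our transformer), we can call adamw_update repeatedly on $L(\theta) = (\theta - 0.5)^2$ and watch $\theta$ climb toward 0.5 as the moments accumulate:

# Toy demonstration: repeated AdamW updates on L(θ) = (θ - 0.5)².
theta_demo, m_demo, v_demo = 0.024634, 0.0, 0.0
for step in range(1, 201):
    grad = 2 * (theta_demo - 0.5)              # dL/dθ for the toy loss
    theta_demo, m_demo, v_demo = adamw_update(
        theta_demo, grad, m_demo, v_demo, step)
    if step in (1, 10, 50, 200):
        print(f"step {step:>3}: θ = {theta_demo:.6f}, m = {m_demo:+.6f}")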

Why the Adaptive Learning Rate Matters

Notice that our effective learning rate was ~2.83× the base learning rate. This happened because:

$$\text{Effective LR} = \frac{\alpha}{\sqrt{\hat{v}} + \epsilon} = \frac{0.001}{0.353 + 10^{-8}} \approx 0.00283$$

The denominator $\sqrt{\hat{v}}$ is essentially the RMS (root mean square) of recent gradients for this parameter.

  • Large gradients → Large $\hat{v}$ → Small effective LR → Cautious updates

  • Small gradients → Small $\hat{v}$ → Large effective LR → Aggressive updates

This automatically adapts to each parameter’s gradient scale. Parameters deep in the network often have smaller gradients (due to the chain rule), but Adam compensates by giving them larger effective learning rates.

# Demonstrate how v_hat affects the effective learning rate
print("How Gradient Magnitude Affects Learning Rate")
print("="*60)
print()
print(f"{'|gradient|':>12} {'v̂ (≈g²)':>12} {'Effective LR':>15} {'Multiplier':>12}")
print("-"*60)

test_grads = [0.001, 0.01, 0.1, 0.5, 1.0, 5.0]
for g_test in test_grads:
    v_test = g_test**2  # Simplified: v ≈ g² at t=1
    eff_lr = learning_rate / (math.sqrt(v_test) + epsilon)
    mult = eff_lr / learning_rate
    print(f"{g_test:>12.3f} {v_test:>12.6f} {eff_lr:>15.6f} {mult:>12.2f}×")

print()
print("Small gradients get large multipliers; large gradients get small ones.")
print("This is what makes Adam 'adaptive'.")
How Gradient Magnitude Affects Learning Rate
============================================================

  |gradient|     v̂ (≈g²)    Effective LR   Multiplier
------------------------------------------------------------
       0.001     0.000001        0.999990       999.99×
       0.010     0.000100        0.100000       100.00×
       0.100     0.010000        0.010000        10.00×
       0.500     0.250000        0.002000         2.00×
       1.000     1.000000        0.001000         1.00×
       5.000    25.000000        0.000200         0.20×

Small gradients get large multipliers; large gradients get small ones.
This is what makes Adam 'adaptive'.

Updating All Parameters

We apply this same AdamW update to every single parameter in the model. Let’s count them:

| Component | Shape | Count |
|---|---|---|
| Token embeddings $E_{token}$ | [6, 16] | 96 |
| Position embeddings $E_{pos}$ | [5, 16] | 80 |
| Attention $W_Q$ (×2 heads) | [16, 8] × 2 | 256 |
| Attention $W_K$ (×2 heads) | [16, 8] × 2 | 256 |
| Attention $W_V$ (×2 heads) | [16, 8] × 2 | 256 |
| Attention $W_O$ | [16, 16] | 256 |
| FFN $W_1$ | [64, 16] | 1,024 |
| FFN $b_1$ | [64] | 64 |
| FFN $W_2$ | [16, 64] | 1,024 |
| FFN $b_2$ | [16] | 16 |
| Layer norm $\gamma$ | [16] | 16 |
| Layer norm $\beta$ | [16] | 16 |
| LM head $W_{lm}$ | [6, 16] | 96 |
| **Total** | | **3,456** |

Each parameter has its own $m$ and $v$ state, which means AdamW needs to store 2× as many values as there are parameters. This is the memory cost of adaptive optimizers.

# Parameter count breakdown
param_counts = {
    "Token embeddings": 6 * 16,
    "Position embeddings": 5 * 16,
    "W_Q (2 heads)": 2 * 16 * 8,
    "W_K (2 heads)": 2 * 16 * 8,
    "W_V (2 heads)": 2 * 16 * 8,
    "W_O": 16 * 16,
    "FFN W_1": 64 * 16,
    "FFN b_1": 64,
    "FFN W_2": 16 * 64,
    "FFN b_2": 16,
    "Layer norm γ": 16,
    "Layer norm β": 16,
    "LM head W_lm": 6 * 16,
}

print("Parameter Count")
print("="*40)
total = 0
for name, count in param_counts.items():
    print(f"  {name:25s}: {count:>6,}")
    total += count
print("-"*40)
print(f"  {'Total':25s}: {total:>6,}")
print()
print(f"AdamW state (m and v): {2 * total:,} values")
Parameter Count
========================================
  Token embeddings         :     96
  Position embeddings      :     80
  W_Q (2 heads)            :    256
  W_K (2 heads)            :    256
  W_V (2 heads)            :    256
  W_O                      :    256
  FFN W_1                  :  1,024
  FFN b_1                  :     64
  FFN W_2                  :  1,024
  FFN b_2                  :     16
  Layer norm γ             :     16
  Layer norm β             :     16
  LM head W_lm             :     96
----------------------------------------
  Total                    :  3,456

AdamW state (m and v): 6,912 values

The Complete Training Loop

We’ve now traced through one complete training step:

  1. Forward pass (Notebooks 01-07)

    • Tokenize input → Embeddings → Attention → FFN → Layer Norm → Logits → Loss

  2. Backward pass (Notebooks 08-09)

    • Compute $\frac{\partial L}{\partial \text{logits}}$ → Backprop through each layer → Gradients for all parameters

  3. Optimization (This notebook)

    • Update each parameter using AdamW

In pseudocode:

# Initialize model parameters randomly
# Initialize Adam state (m=0, v=0 for each parameter)
t = 0

for epoch in range(num_epochs):
    for batch in training_data:
        t += 1
        
        # Forward pass
        predictions = model.forward(batch)
        loss = cross_entropy(predictions, targets)
        
        # Backward pass
        gradients = model.backward(loss)
        
        # Optimization
        for param in model.parameters:
            param.value, param.m, param.v = adamw_update(
                param.value, param.gradient, param.m, param.v, t
            )

Repeat this loop millions of times. Each iteration, the loss gets smaller. The model gets better.

Scaling Up

Our tiny model has ~3,500 parameters. Real language models have:

| Model | Parameters |
|---|---|
| GPT-2 Small | 124 million |
| GPT-2 Large | 774 million |
| GPT-3 | 175 billion |
| LLaMA 2 70B | 70 billion |
| Claude 3 | Unknown (but large) |

The math is identical. Every one of those billions of parameters goes through the same AdamW update we just computed. The difference is scale—and the engineering required to parallelize the computation across thousands of GPUs.
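
A back-of-envelope sketch of what that state costs at scale (assuming fp32 storage and counting only the weights plus m and v; real training also stores gradients and activations):

# Rough memory for weights + AdamW state (m and v), at 4 bytes per value.
for name, n_params in [("GPT-2 Small", 124e6), ("GPT-3", 175e9)]:
    total_gb = 3 * n_params * 4 / 1e9
    print(f"{name:>12}: ~{total_gb:,.1f} GB for weights and optimizer state")

Keeping three values per parameter is one reason large-scale training shards optimizer state across many GPUs.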

But the algorithm is exactly what we’ve shown: forward pass, backward pass, AdamW update. Repeat.

What We’ve Accomplished

We calculated—by hand, with explicit numbers—a complete training step through a transformer:

Forward pass:

  • Tokenization and embedding lookup

  • Q/K/V projections for attention

  • Scaled dot-product attention with causal masking

  • Multi-head concatenation and output projection

  • Feed-forward network with GELU activation

  • Layer normalization with residual connections

  • Output projection and softmax

  • Cross-entropy loss computation

Backward pass:

  • Gradient of loss with respect to logits

  • Backpropagation through every layer using the chain rule

  • Gradients for all ~3,500 parameters

Optimization:

  • AdamW update with momentum, adaptive learning rates, and weight decay

Nothing was hidden. No magic. Just math.

Closing Thoughts

You’ve made it through the entire pipeline.

You’ve seen every matrix multiplication, every activation function, every gradient calculation, every weight update. You understand how a transformer processes text, why attention works, what backpropagation computes, and how optimization nudges parameters toward better predictions.

When someone says “a transformer learns by gradient descent,” you now know exactly what that means—down to the individual floating-point operations.

This is the difference between knowing about something and truly understanding it. You don’t just know that transformers use attention mechanisms—you’ve computed the attention scores yourself. You don’t just know that neural networks use backpropagation—you’ve traced the chain rule through every layer.

The next time you use GPT, Claude, or any language model, remember: under the hood, it’s doing exactly what we did here. Billions of times larger. Trillions of times repeated. But the same fundamental math.

Transformers aren’t magic. They’re just math.

And now you understand the math.

print("Training Step Complete")
print("="*50)
print()
print("  Forward pass:  ✓")
print("  Backward pass: ✓")
print("  Optimization:  ✓")
print()
print("One training iteration done.")
print("Repeat ~billions of times for a real LLM.")
Training Step Complete
==================================================

  Forward pass:  ✓
  Backward pass: ✓
  Optimization:  ✓

One training iteration done.
Repeat ~billions of times for a real LLM.