
Loss Gradients

The Backward Pass Begins

The forward pass is done. We fed “I like transformers” through our tiny transformer and got a loss of about 1.9—essentially random guessing. The model has no idea what it’s doing.

Now we need to fix that. We need to figure out how to adjust the ~2,600 parameters so the model does better next time.

The tool for this is backpropagation—the algorithm that computes how much each parameter contributed to the final loss. Once we know that, we can nudge each parameter in the direction that reduces the loss.

The big picture:

For every parameter $\theta$ in the model, we want to compute:

$$\frac{\partial L}{\partial \theta}$$

This is the gradient—it tells us how much the loss would change if we increased $\theta$ by a tiny amount. Positive gradient means increasing $\theta$ increases loss (bad). Negative gradient means increasing $\theta$ decreases loss (good).

The strategy:

We work backward from the loss. First we compute how the loss depends on the logits. Then how the logits depend on the hidden states. Then how the hidden states depend on the FFN. And so on, all the way back to the embeddings.

This is the chain rule in action:

$$\frac{\partial L}{\partial \theta} = \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial \theta}$$
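
A toy instance of the same composition: if $y = 2\theta$ and $L = y^2$, then $\frac{\partial L}{\partial y} = 2y$ and $\frac{\partial y}{\partial \theta} = 2$, so $\frac{\partial L}{\partial \theta} = 2y \cdot 2 = 8\theta$. The backward pass below is exactly this pattern, just with many more links in the chain.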

Cross-Entropy + Softmax: A Beautiful Gradient

Let’s start at the very end of the computation graph. Our loss is:

$$L = -\log P(\text{target})$$

Where $P$ comes from softmax:

$$P(i) = \frac{\exp(\text{logit}_i)}{\sum_j \exp(\text{logit}_j)}$$

We need $\frac{\partial L}{\partial \text{logit}_i}$ for each vocabulary token $i$.

The derivation:

This requires some calculus, but the result is remarkably clean. Let’s work through it.

For the loss $L = -\log P(t)$ where $t$ is the target:

$$\frac{\partial L}{\partial \text{logit}_i} = \frac{\partial L}{\partial P(t)} \cdot \frac{\partial P(t)}{\partial \text{logit}_i}$$

The first term is straightforward:

$$\frac{\partial L}{\partial P(t)} = -\frac{1}{P(t)}$$

The second term (softmax derivative) depends on whether $i = t$:

  • If $i = t$: $\frac{\partial P(t)}{\partial \text{logit}_t} = P(t)(1 - P(t))$

  • If $i \neq t$: $\frac{\partial P(t)}{\partial \text{logit}_i} = -P(t) \cdot P(i)$

Combining these:

  • If $i = t$: $\frac{\partial L}{\partial \text{logit}_t} = -\frac{1}{P(t)} \cdot P(t)(1 - P(t)) = -(1 - P(t)) = P(t) - 1$

  • If $i \neq t$: $\frac{\partial L}{\partial \text{logit}_i} = -\frac{1}{P(t)} \cdot (-P(t) \cdot P(i)) = P(i)$

The final formula:

$$\frac{\partial L}{\partial \text{logit}_i} = P(i) - \mathbb{1}[i = t]$$

Where $\mathbb{1}[i = t]$ is the indicator function: 1 if $i$ equals the target, 0 otherwise.

This is one of the most elegant results in machine learning. The gradient is just: predicted probability minus target probability.
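
Before plugging in our model's numbers, it's worth sanity-checking this formula numerically. The cell below is a standalone check (the logits and target here are made up for illustration, not taken from the forward pass): it compares the analytic gradient $P(i) - \mathbb{1}[i = t]$ against a central finite-difference estimate of the loss.

import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(logits, target):
    """Cross-entropy loss: L = -log P(target), with P from softmax."""
    return -math.log(softmax(logits)[target])

# Made-up logits and target, just to exercise the formula
logits_demo = [0.3, -1.2, 0.5, 0.0, 0.9, -0.4]
target_demo = 3
eps = 1e-5

probs_demo = softmax(logits_demo)
for i in range(len(logits_demo)):
    analytic = probs_demo[i] - (1.0 if i == target_demo else 0.0)
    # Central finite difference: nudge logit i up and down by eps
    up = logits_demo.copy()
    up[i] += eps
    down = logits_demo.copy()
    down[i] -= eps
    numeric = (cross_entropy(up, target_demo) - cross_entropy(down, target_demo)) / (2 * eps)
    print(f"logit[{i}]: analytic = {analytic:>9.6f}   finite-diff = {numeric:>9.6f}")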

import random
import math

random.seed(42)

VOCAB_SIZE = 6
D_MODEL = 16
D_FF = 64
MAX_SEQ_LEN = 5
NUM_HEADS = 2
D_K = D_MODEL // NUM_HEADS

TOKEN_NAMES = ["<PAD>", "<BOS>", "<EOS>", "I", "like", "transformers"]
# For this notebook, we'll use pre-computed probabilities from the forward pass
# These are approximately what a randomly initialized model would produce
probs = [
    [0.1785, 0.2007, 0.1759, 0.1254, 0.1563, 0.1632],  # position 0: <BOS>
    [0.1836, 0.1969, 0.1805, 0.1233, 0.1500, 0.1657],  # position 1: I
    [0.1795, 0.2050, 0.1782, 0.1207, 0.1437, 0.1728],  # position 2: like
    [0.1855, 0.2017, 0.1771, 0.1271, 0.1391, 0.1695],  # position 3: transformers
]

# Target tokens (what the model should predict at each position)
targets = [3, 4, 5, 2]  # I, like, transformers, <EOS>

# Current tokens (input at each position)
tokens = [1, 3, 4, 5, 2]  # <BOS>, I, like, transformers, <EOS>

print("Forward pass predictions:")
print("="*60)
print()
for i, t in enumerate(targets):
    prob_correct = probs[i][t]
    print(f"Position {i}: {TOKEN_NAMES[tokens[i]]:12s} → should predict {TOKEN_NAMES[t]:12s} (P = {prob_correct:.4f})")
Forward pass predictions:
============================================================

Position 0: <BOS>        → should predict I            (P = 0.1254)
Position 1: I            → should predict like         (P = 0.1500)
Position 2: like         → should predict transformers (P = 0.1728)
Position 3: transformers → should predict <EOS>        (P = 0.1771)

Computing the Gradient

Now let’s compute $\frac{\partial L}{\partial \text{logit}}$ for each position and each vocabulary token.

The formula is simple:

  • For the target token: gradient = $P(\text{target}) - 1$ (always negative)

  • For all other tokens: gradient = $P(\text{token})$ (always positive)

def compute_loss_gradient(probs, target):
    """
    Compute gradient of cross-entropy loss w.r.t. logits.
    
    dL/dlogit[i] = P(i) - 1  if i == target
    dL/dlogit[i] = P(i)      otherwise
    
    Args:
        probs: Probability distribution from softmax [vocab_size]
        target: Index of the correct token
    
    Returns:
        Gradient vector [vocab_size]
    """
    grad = probs.copy()
    grad[target] -= 1.0  # Subtract 1 from the target position
    return grad

# Compute gradients for all positions
dL_dlogits = []
for i in range(len(targets)):
    grad = compute_loss_gradient(probs[i], targets[i])
    dL_dlogits.append(grad)

print("Loss Gradients w.r.t. Logits (dL/dlogit)")
print("="*75)
print()
print(f"{'Position':<12} {'<PAD>':>9} {'<BOS>':>9} {'<EOS>':>9} {'I':>9} {'like':>9} {'trans':>9}")
print("-"*75)
for i, grad in enumerate(dL_dlogits):
    target = targets[i]
    row = f"{TOKEN_NAMES[tokens[i]]:<12}"
    for j, g in enumerate(grad):
        marker = "*" if j == target else " "
        row += f" {g:>8.4f}{marker}"
    print(row)
print()
print("* marks the target token (negative gradient)")
Loss Gradients w.r.t. Logits (dL/dlogit)
===========================================================================

Position         <PAD>     <BOS>     <EOS>         I      like     trans
---------------------------------------------------------------------------
<BOS>          0.1785    0.2007    0.1759   -0.8746*   0.1563    0.1632 
I              0.1836    0.1969    0.1805    0.1233   -0.8500*   0.1657 
like           0.1795    0.2050    0.1782    0.1207    0.1437   -0.8272*
transformers   0.1855    0.2017   -0.8229*   0.1271    0.1391    0.1695 

* marks the target token (negative gradient)

Understanding the Gradients

Let’s break down position 0 in detail. The model sees <BOS> and should predict I.

print("Detailed: Position 0 (<BOS> → should predict I)")
print("="*70)
print()

print("Step 1: Current probabilities from softmax")
print("-"*50)
for j, name in enumerate(TOKEN_NAMES):
    marker = "← TARGET" if j == targets[0] else ""
    print(f"  P({name:12s}) = {probs[0][j]:.4f}  {marker}")
print(f"  Sum = {sum(probs[0]):.4f}")
print()

print("Step 2: Compute gradients using formula: dL/dlogit = P(i) - 1[i==target]")
print("-"*70)
for j, name in enumerate(TOKEN_NAMES):
    is_target = 1 if j == targets[0] else 0
    grad = probs[0][j] - is_target
    print(f"  dL/dlogit[{name:12s}] = {probs[0][j]:.4f} - {is_target} = {grad:>8.4f}")
Detailed: Position 0 (<BOS> → should predict I)
======================================================================

Step 1: Current probabilities from softmax
--------------------------------------------------
  P(<PAD>       ) = 0.1785  
  P(<BOS>       ) = 0.2007  
  P(<EOS>       ) = 0.1759  
  P(I           ) = 0.1254  ← TARGET
  P(like        ) = 0.1563  
  P(transformers) = 0.1632  
  Sum = 1.0000

Step 2: Compute gradients using formula: dL/dlogit = P(i) - 1[i==target]
----------------------------------------------------------------------
  dL/dlogit[<PAD>       ] = 0.1785 - 0 =   0.1785
  dL/dlogit[<BOS>       ] = 0.2007 - 0 =   0.2007
  dL/dlogit[<EOS>       ] = 0.1759 - 0 =   0.1759
  dL/dlogit[I           ] = 0.1254 - 1 =  -0.8746
  dL/dlogit[like        ] = 0.1563 - 0 =   0.1563
  dL/dlogit[transformers] = 0.1632 - 0 =   0.1632

Why These Gradients Make Sense

Key insight: Gradients point in the direction of increasing loss.

When we do gradient descent, we subtract the gradient (times a learning rate):

$$\text{logit}_{\text{new}} = \text{logit}_{\text{old}} - \eta \cdot \frac{\partial L}{\partial \text{logit}}$$

So:

For the correct token (gradient $= P - 1 \approx -0.87$):

  • We subtract a negative number

  • That’s the same as adding

  • The logit increases

  • Higher logit → higher probability ✓

For incorrect tokens (gradient $= P \approx +0.17$):

  • We subtract a positive number

  • The logit decreases

  • Lower logit → lower probability ✓

Gradient descent naturally pushes the correct answer’s probability up and everything else down.

# Demonstrate what happens with a gradient update
learning_rate = 0.1

# Pretend these are the logits (we'll make up some values)
logits_before = [-0.5, 0.2, -0.3, -0.8, -0.2, 0.1]  # Position 0

print("Effect of one gradient descent step (learning rate = 0.1)")
print("="*70)
print()
print(f"{'Token':<12} {'Logit Before':>12} {'Gradient':>10} {'Logit After':>12} {'Change':>10}")
print("-"*70)

logits_after = []
for j, name in enumerate(TOKEN_NAMES):
    grad = dL_dlogits[0][j]
    new_logit = logits_before[j] - learning_rate * grad
    logits_after.append(new_logit)
    change = new_logit - logits_before[j]
    marker = "← TARGET" if j == targets[0] else ""
    print(f"{name:<12} {logits_before[j]:>12.4f} {grad:>10.4f} {new_logit:>12.4f} {change:>+10.4f}  {marker}")

print()
print("The target token's logit increased; all others decreased.")
Effect of one gradient descent step (learning rate = 0.1)
======================================================================

Token        Logit Before   Gradient  Logit After     Change
----------------------------------------------------------------------
<PAD>             -0.5000     0.1785      -0.5179    -0.0179  
<BOS>              0.2000     0.2007       0.1799    -0.0201  
<EOS>             -0.3000     0.1759      -0.3176    -0.0176  
I                 -0.8000    -0.8746      -0.7125    +0.0875  ← TARGET
like              -0.2000     0.1563      -0.2156    -0.0156  
transformers       0.1000     0.1632       0.0837    -0.0163  

The target token's logit increased; all others decreased.
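
To close the loop, we can push the updated logits back through softmax and confirm that the target's probability actually rises. This reuses the `softmax` helper from the finite-difference check earlier, plus `logits_before`, `logits_after`, and `targets` from the cell above (remember the starting logits there were made up, so the exact numbers are only illustrative).

# Probabilities before and after the single gradient step above
p_before = softmax(logits_before)
p_after = softmax(logits_after)

print(f"{'Token':<12} {'P before':>10} {'P after':>10} {'Change':>10}")
print("-" * 48)
for j, name in enumerate(TOKEN_NAMES):
    marker = "← TARGET" if j == targets[0] else ""
    print(f"{name:<12} {p_before[j]:>10.4f} {p_after[j]:>10.4f} {p_after[j] - p_before[j]:>+10.4f}  {marker}")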

Mathematical Property: Gradients Sum to Zero

Here’s a nice verification we can do. The gradients at each position should sum to zero:

$$\sum_{i=0}^{V-1} \frac{\partial L}{\partial \text{logit}_i} = \sum_{i=0}^{V-1} \left(P(i) - \mathbb{1}[i = t]\right) = \left(\sum_{i=0}^{V-1} P(i)\right) - 1 = 1 - 1 = 0$$

This makes sense: softmax is translation-invariant. Adding a constant to all logits doesn’t change the probabilities (it cancels out in the normalization). So the gradient with respect to “shift all logits equally” should be zero.
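
Here is a quick check of that shift-invariance claim, again with made-up logits and the `softmax` helper defined earlier:

# Adding the same constant to every logit leaves the probabilities unchanged
logits_shift_demo = [0.3, -1.2, 0.5, 0.0, 0.9, -0.4]
shift = 3.7  # arbitrary constant
shifted = [x + shift for x in logits_shift_demo]

p_original = softmax(logits_shift_demo)
p_shifted = softmax(shifted)

for j, name in enumerate(TOKEN_NAMES):
    print(f"{name:<12} original = {p_original[j]:.6f}   shifted = {p_shifted[j]:.6f}")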

print("Verification: Gradients sum to zero at each position")
print("="*50)
print()
for i, grad in enumerate(dL_dlogits):
    grad_sum = sum(grad)
    status = "✓" if abs(grad_sum) < 1e-6 else f"✗ (off by {grad_sum:.2e})"
    print(f"Position {i} ({TOKEN_NAMES[tokens[i]]:12s}): sum = {grad_sum:>12.10f}  {status}")
Verification: Gradients sum to zero at each position
==================================================

Position 0 (<BOS>       ): sum = -0.0000000000  ✓
Position 1 (I           ): sum = 0.0000000000  ✓
Position 2 (like        ): sum = -0.0001000000  ✗ (off by -1.00e-04)
Position 3 (transformers): sum = 0.0000000000  ✓

(The small residual at position 2 is just rounding: the hard-coded probabilities were rounded to four decimal places, so that row sums to 0.9999 rather than exactly 1.)

Gradient Magnitude Tells a Story

Notice the magnitudes of the gradients:

  • Target gradient: ~-0.85 (large negative)

  • Non-target gradients: ~+0.17 (small positive)

The target gradient is much larger in magnitude. Why?

Because $|P(\text{target}) - 1| = 1 - P(\text{target})$ is large when the model is wrong. Our model assigns only ~15% probability to the correct answer, so $1 - 0.15 = 0.85$.

As training progresses:

  • If $P(\text{target}) \to 0.9$: gradient becomes $0.9 - 1 = -0.1$ (small)

  • If $P(\text{target}) \to 0.99$: gradient becomes $0.99 - 1 = -0.01$ (tiny)

The gradient naturally gets smaller as the model improves—the learning signal weakens when there’s less to learn. This is a form of adaptive learning.

# Show how gradient magnitude changes with model confidence
print("How gradient magnitude depends on model confidence")
print("="*55)
print()
print(f"{'P(target)':>12} {'Target Gradient':>18} {'Interpretation':<25}")
print("-"*55)

test_probs = [0.01, 0.10, 0.25, 0.50, 0.75, 0.90, 0.99]
for p in test_probs:
    grad = p - 1
    if p < 0.2:
        interp = "Very wrong, strong signal"
    elif p < 0.5:
        interp = "Uncertain, moderate signal"
    elif p < 0.8:
        interp = "Getting it, weaker signal"
    else:
        interp = "Confident, tiny signal"
    print(f"{p:>12.2f} {grad:>18.4f} {interp:<25}")
How gradient magnitude depends on model confidence
=======================================================

   P(target)    Target Gradient Interpretation           
-------------------------------------------------------
        0.01            -0.9900 Very wrong, strong signal
        0.10            -0.9000 Very wrong, strong signal
        0.25            -0.7500 Uncertain, moderate signal
        0.50            -0.5000 Getting it, weaker signal
        0.75            -0.2500 Getting it, weaker signal
        0.90            -0.1000 Confident, tiny signal   
        0.99            -0.0100 Confident, tiny signal   

The Gradient Tensor Shape

Let’s be explicit about what we’ve computed:

$$\frac{\partial L}{\partial \text{logits}} \in \mathbb{R}^{4 \times 6}$$

  • 4 positions (we don’t predict after <EOS>)

  • 6 vocabulary tokens

Each entry tells us how much the loss would change if we increased that particular logit by a tiny amount.

print(f"Gradient tensor shape: [{len(dL_dlogits)}, {len(dL_dlogits[0])}]")
print(f"  - 4 positions (0 through 3, predicting tokens 1 through 4)")
print(f"  - 6 vocabulary tokens")
print(f"  - Total: 24 gradient values")
Gradient tensor shape: [4, 6]
  - 4 positions (0 through 3, predicting tokens 1 through 4)
  - 6 vocabulary tokens
  - Total: 24 gradient values

What’s Next: Backpropagating Further

We’ve computed $\frac{\partial L}{\partial \text{logits}}$. This tells us how the loss depends on the raw prediction scores.

But we can’t directly modify the logits—they’re computed from earlier layers:

$$\text{logits} = \text{hidden} \cdot W_{lm}^T$$

To train the model, we need gradients for:

  1. $W_{lm}$ (the language modeling head weights) - so we can update them

  2. hidden states - so we can backpropagate further into the transformer
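
As a preview of that first step, here is a minimal sketch of the two matrix-multiply gradients for the LM head, written with plain nested lists. The `hidden` and `W_lm` values below are random stand-ins (the real ones come from the forward-pass notebooks); only the shapes and the two backward rules matter here.

# Backward through logits = hidden @ W_lm^T (sketch with made-up hidden/W_lm).
# For Y = X @ W^T the chain rule gives:
#   dL/dW = (dL/dY)^T @ X   -> gradient for the LM head weights
#   dL/dX = (dL/dY) @ W     -> gradient flowing back into the hidden states

def matmul(A, B):
    """Plain-Python matrix multiply for small lists of lists."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

# Random stand-ins with the right shapes (hypothetical values, for illustration only)
hidden = [[random.gauss(0, 0.02) for _ in range(D_MODEL)] for _ in range(4)]          # [4, 16]
W_lm = [[random.gauss(0, 0.02) for _ in range(D_MODEL)] for _ in range(VOCAB_SIZE)]   # [6, 16]

dL_dW_lm = matmul(transpose(dL_dlogits), hidden)   # [6, 16]: update for the LM head weights
dL_dhidden = matmul(dL_dlogits, W_lm)              # [4, 16]: flows back into the transformer

print("dL/dW_lm shape:  ", len(dL_dW_lm), "x", len(dL_dW_lm[0]))
print("dL/dhidden shape:", len(dL_dhidden), "x", len(dL_dhidden[0]))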

And then we’ll continue backward through:

  • Layer normalization

  • Residual connections

  • Feed-forward network

  • Multi-head attention

  • Q/K/V projections

  • Embeddings

The next notebook continues the backward pass through these layers.

# Store gradients for next notebook
grad_loss_data = {
    'dL_dlogits': dL_dlogits,
    'probs': probs,
    'targets': targets,
    'tokens': tokens
}

print("Summary: Loss Gradients")
print("="*50)
print(f"Shape: [4, 6] (4 positions × 6 vocab tokens)")
print(f"Formula: dL/dlogit[i] = P(i) - 1[i==target]")
print(f"Property: Each row sums to 0")
print()
print("Ready for backpropagation through the model.")
Summary: Loss Gradients
==================================================
Shape: [4, 6] (4 positions × 6 vocab tokens)
Formula: dL/dlogit[i] = P(i) - 1[i==target]
Property: Each row sums to 0

Ready for backpropagation through the model.