RLHF Training Dynamics

You’ve seen the pieces — policy models, reward models, value networks, PPO. Now let’s watch them dance together.

This is where RLHF gets real. We’re going to break down exactly what happens during training, step by step, no hand-waving. By the end, you’ll understand the heartbeat of RLHF: generate, evaluate, improve, repeat.

The Two-Phase Dance

RLHF training has a rhythm. Each iteration does two things:

Phase 1: Rollout — Your policy generates responses. You score them. You remember what happened.

Phase 2: Update — You use that data to make your policy better. Then you do it again.

Think of it like learning to cook. First you make a dish (rollout). Then you taste it, figure out what worked and what didn’t, and adjust your technique (update). Rinse and repeat until you’re a chef.

The key insight? You don’t learn while you’re generating. You generate first, then learn from what you generated. This separation lets the update compare the new policy against a fixed snapshot of the old one, and that careful comparison is what keeps PPO stable.
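
Here’s that rhythm as a bare-bones sketch. Both functions below are just stand-in stubs; the real rollout collection and PPO update get built piece by piece through the rest of this page:

# A bare-bones sketch of the two-phase rhythm. Both functions are stubs;
# the real rollout collection and PPO update are built up below.

def collect_rollout(policy):
    """Phase 1 (stub): generate responses, score them, record everything."""
    return {"data": "a frozen snapshot of what the policy just did"}

def ppo_update(policy, rollout):
    """Phase 2 (stub): improve the policy using the recorded rollout."""
    pass

def train(policy=None, num_iterations=3, ppo_epochs=4):
    for iteration in range(num_iterations):
        rollout = collect_rollout(policy)   # generate first (no learning here)...
        for _ in range(ppo_epochs):         # ...then learn from it, several passes
            ppo_update(policy, rollout)

train()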

What’s a Rollout, Anyway?

Here’s a word you’ll hear constantly in RL: rollout.

A rollout is just... running your policy and recording what happens. That’s it.

In RLHF terms:

  • You give your model some prompts

  • It generates responses (one token at a time)

  • You score those responses with your reward model

  • You ask your value network what it thinks

  • You save everything to memory

Think of it like recording a practice session. You’re not evaluating yourself mid-sentence — you finish your full response, then you look at what you did. The recording (the rollout) lets you study your performance afterward.

In each rollout we store:

import torch
from dataclasses import dataclass

@dataclass
class RolloutBatch:
    """Everything we need to remember from generating responses."""
    
    # What we started with
    query_tensors: torch.Tensor       # The prompts we gave the model
    response_tensors: torch.Tensor    # What the model generated
    
    # What the models thought at generation time
    logprobs: torch.Tensor            # How confident was our policy?
    ref_logprobs: torch.Tensor        # How confident was the reference model?
    values: torch.Tensor              # What did the value network predict?
    
    # What we computed afterward
    rewards: torch.Tensor             # Reward model scores (the "grades")
    advantages: torch.Tensor          # "How much better than expected?" (we'll explain this!)
    returns: torch.Tensor             # Target values for training the value network

# Let's visualize what a rollout looks like
print("A RolloutBatch is like a student's exam:")
print()
print("  query_tensors     →  The questions")
print("  response_tensors  →  The student's answers")
print("  logprobs          →  Student's confidence in each answer")
print("  ref_logprobs      →  What a baseline student would've done")
print("  values            →  Student's prediction of their grade")
print("  rewards           →  The actual grade from the teacher")
print("  advantages        →  Did they do better or worse than expected?")
print("  returns           →  What the grade *should've* been")
print()
print("We store all of this so we can learn from it later.")
A RolloutBatch is like a student's exam:

  query_tensors     →  The questions
  response_tensors  →  The student's answers
  logprobs          →  Student's confidence in each answer
  ref_logprobs      →  What a baseline student would've done
  values            →  Student's prediction of their grade
  rewards           →  The actual grade from the teacher
  advantages        →  Did they do better or worse than expected?
  returns           →  What the grade *should've* been

We store all of this so we can learn from it later.
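
To make the shapes concrete, here’s a throwaway RolloutBatch filled with random tensors. The shapes are purely illustrative; in real training every field comes from the actual models, as the rest of this page shows:

# Illustrative only: a dummy RolloutBatch with made-up tensors.
batch_size, prompt_len, response_len = 2, 6, 10

dummy = RolloutBatch(
    query_tensors=torch.randint(0, 50_000, (batch_size, prompt_len)),       # prompt token ids
    response_tensors=torch.randint(0, 50_000, (batch_size, response_len)),  # generated token ids
    logprobs=torch.randn(batch_size, response_len),      # one log prob per generated token
    ref_logprobs=torch.randn(batch_size, response_len),  # same, from the reference model
    values=torch.randn(batch_size, response_len),        # one value estimate per token
    rewards=torch.randn(batch_size, response_len),       # per-token reward signal
    advantages=torch.zeros(batch_size, response_len),    # filled in by GAE (next section)
    returns=torch.zeros(batch_size, response_len),       # filled in by GAE (next section)
)

print(dummy.logprobs.shape)   # torch.Size([2, 10]), i.e. (batch, response_len)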

Advantages: The Key Insight

Now we get to the really clever bit. Let’s talk about advantages.

Here’s the core question in reinforcement learning: which actions were actually good?

You might think “just use the rewards!” But that’s too naive. Imagine you’re playing basketball, and your team wins 100-98. Every action you took gets associated with “winning” — but some of your shots were terrible and you just got lucky.

What you really want to know is: which actions were better than expected?

That’s an advantage. It’s not the absolute reward. It’s the surprise.

If your value network predicted you’d get a reward of 5, and you got 7, your advantage is +2. You did better than expected! Reinforce that behavior.

If it predicted 8 and you got 7, your advantage is -1. You did worse than expected. Maybe don’t do that again.
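
In code, that toy arithmetic is a single subtraction (one step only; GAE, described next, handles whole sequences):

# Toy single-step example: advantage = what actually happened - what we expected.
predicted_value = 5.0   # the value network's forecast
actual_reward = 7.0     # the reward model's score
advantage = actual_reward - predicted_value
print(advantage)        # 2.0: better than expected, so reinforce this behavior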

This is where GAE comes in — Generalized Advantage Estimation. (Fancy name, but the intuition is simple: it’s a smart way to compute these “better than expected” signals.)

The formula looks scary, but let’s break it down:

def compute_gae(
    rewards: torch.Tensor,
    values: torch.Tensor,
    gamma: float = 0.99,
    lam: float = 0.95
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Compute Generalized Advantage Estimation (GAE).
    
    This answers: "How much better/worse was each action compared to what we expected?"
    
    The math (don't panic):
        A_t = Σ (γλ)^l * δ_{t+l}
    
    Where δ_t is the "TD error" (Temporal Difference error):
        δ_t = r_t + γV(s_{t+1}) - V(s_t)
    
    In English:
        δ_t = "actual reward" + "discounted future value" - "predicted value"
        δ_t = "what happened" - "what we expected"
    
    Then we sum up these errors over time, with exponential decay (γλ).
    
    Args:
        rewards: What actually happened, shape (batch, seq_len)
        values: What we predicted would happen, shape (batch, seq_len)
        gamma: Discount factor (how much we care about future rewards)
                Higher = "future matters", lower = "only now matters"
        lam: GAE lambda (bias-variance knob, 0.95 is standard)
                Higher = trust multi-step returns, lower = trust single-step
    
    Returns:
        advantages: "Better or worse than expected?" for each timestep
        returns: What the value *should've* been (for training value network)
    """
    batch_size, seq_len = rewards.shape
    advantages = torch.zeros_like(rewards)
    last_gae = 0
    
    # We go backwards through time (from end to beginning)
    # Why? Because we need to know the future to compute "expected future value"
    for t in reversed(range(seq_len)):
        if t == seq_len - 1:
            next_value = 0  # End of sequence, no future
        else:
            next_value = values[:, t + 1]
        
        # TD error: "what we got" - "what we expected"
        # r_t = immediate reward
        # γ * V(s_{t+1}) = discounted predicted future
        # V(s_t) = what we thought we'd get total
        delta = rewards[:, t] + gamma * next_value - values[:, t]
        
        # GAE accumulation: sum up errors with exponential decay
        # This smooths out noise and balances bias vs variance
        last_gae = delta + gamma * lam * last_gae
        advantages[:, t] = last_gae
    
    # Returns are what the value network should've predicted
    # They're the "ground truth" for training it
    returns = advantages + values
    
    return advantages, returns

# Let's see it in action
batch_size, seq_len = 2, 10

# Imagine we got these rewards (somewhat random, centered near 0)
rewards = torch.randn(batch_size, seq_len) * 0.5

# And our value network predicted these values
values = torch.randn(batch_size, seq_len)

advantages, returns = compute_gae(rewards, values)

print("Example rollout:")
print(f"  Rewards shape: {rewards.shape} — what actually happened")
print(f"  Values shape:  {values.shape}  — what we predicted")
print()
print(f"  Advantages shape: {advantages.shape} — surprise signal (positive = better than expected)")
print(f"  Returns shape:    {returns.shape}    — target for value network training")
print()
print("Sample advantages (first sequence, first 5 timesteps):")
print(advantages[0, :5])
print()
print("Positive = 'do more of this', negative = 'do less of this'")
Example rollout:
  Rewards shape: torch.Size([2, 10]) — what actually happened
  Values shape:  torch.Size([2, 10])  — what we predicted

  Advantages shape: torch.Size([2, 10]) — surprise signal (positive = better than expected)
  Returns shape:    torch.Size([2, 10])    — target for value network training

Sample advantages (first sequence, first 5 timesteps):
tensor([-1.8015, -3.0808, -1.3732, -3.9525, -1.7698])

Positive = 'do more of this', negative = 'do less of this'
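
A quick way to feel what the (γλ)^l decay inside the GAE sum does: it controls how far into the future each token looks when taking credit or blame. With the default values:

# How the (gamma * lam)^lag weights decay with the defaults gamma=0.99, lam=0.95.
gamma, lam = 0.99, 0.95
for lag in [0, 5, 10, 20, 40, 80]:
    weight = (gamma * lam) ** lag
    print(f"TD error {lag:2d} tokens ahead contributes with weight {weight:.3f}")

Push lam toward 0 and only the immediate TD error counts (more bias, less variance); push it toward 1 and you approach full Monte Carlo returns (less bias, more variance). That’s the “bias-variance knob” the docstring mentions.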

Whitening: Keeping Things Stable

One more trick before we update the policy: advantage normalization (also called “whitening”).

Raw advantages can have wildly different scales. One batch might have advantages from -2 to +2. Another might be -100 to +100. This makes training unstable — your gradients are all over the place.

Solution? Normalize them. Make them have mean 0 and standard deviation 1.

This is like grading on a curve. It doesn’t matter if the test was easy or hard — you’re comparing students to each other, not to an absolute scale.

def whiten_advantages(advantages: torch.Tensor) -> torch.Tensor:
    """
    Normalize advantages to mean=0, std=1.
    
    Why? Training stability. We want consistent gradient scales across batches.
    
    This is "grading on a curve" — comparing actions to each other,
    not to an absolute scale.
    """
    mean = advantages.mean()
    std = advantages.std() + 1e-8  # Small epsilon prevents division by zero
    return (advantages - mean) / std

# Let's see the transformation
print("Before whitening:")
print(f"  Mean: {advantages.mean():.4f}")
print(f"  Std:  {advantages.std():.4f}")
print(f"  Min:  {advantages.min():.4f}")
print(f"  Max:  {advantages.max():.4f}")

whitened = whiten_advantages(advantages)

print()
print("After whitening:")
print(f"  Mean: {whitened.mean():.4f}  ← always ~0")
print(f"  Std:  {whitened.std():.4f}   ← always ~1")
print(f"  Min:  {whitened.min():.4f}")
print(f"  Max:  {whitened.max():.4f}")
print()
print("Now the scale is consistent across all batches!")
Before whitening:
  Mean: -1.2152
  Std:  1.3548
  Min:  -3.9525
  Max:  1.6824

After whitening:
  Mean: 0.0000  ← always ~0
  Std:  1.0000   ← always ~1
  Min:  -2.0205
  Max:  2.1388

Now the scale is consistent across all batches!

Phase 2: The PPO Update

Okay! We’ve generated rollouts. We’ve computed advantages. We’ve whitened them.

Now comes the actual learning: the PPO update.

Remember, PPO stands for Proximal Policy Optimization. The key word is proximal — we want to update the policy, but not too much. Stay close to the old policy. (That’s what the “clipping” does, which we’ll see in the code.)

In each update step:

def ppo_update_step(
    policy_model,
    value_network,
    rollout: RolloutBatch,
    optimizer,
    config
):
    """
    One PPO update step. This is the heart of RLHF training.
    
    We call this multiple times (ppo_epochs) for each rollout.
    Why? Because we have limited data, so we squeeze as much learning
    from it as we can. (But not too much — that's where clipping helps.)
    """
    
    # Step 1: Re-compute outputs with CURRENT policy
    # (The policy has changed since we did the rollout, so we need fresh numbers)
    current_logprobs = policy_model.get_logprobs(
        rollout.query_tensors,
        rollout.response_tensors
    )
    
    current_values = value_network(
        rollout.query_tensors,
        rollout.response_tensors
    )
    
    # Step 2: Compute the "ratio" 
    # This tells us: how much more likely is the current policy to take
    # the same action compared to the old policy?
    #
    # ratio = P_new(action) / P_old(action)
    # = exp(log P_new - log P_old)
    ratio = torch.exp(current_logprobs - rollout.logprobs)
    
    # Step 3: PPO clipped objective (this is THE magic)
    # We compute two versions and take the minimum
    
    # Unclipped: standard policy gradient
    # "How good was this action (advantage) * how much more likely we are to take it (ratio)"
    unclipped = ratio * rollout.advantages
    
    # Clipped: same thing, but we clip the ratio
    # Don't let it get too far from 1.0 (which means "same probability as old policy")
    # clip_ratio is typically 0.2, so we allow 0.8 to 1.2
    clipped = torch.clamp(
        ratio, 
        1 - config['clip_ratio'],  # Lower bound: 0.8
        1 + config['clip_ratio']   # Upper bound: 1.2
    ) * rollout.advantages
    
    # Take the minimum — this is conservative, prevents big updates
    # We want to improve, but carefully
    policy_loss = -torch.min(unclipped, clipped).mean()
    # (Negative because we want to maximize, but optimizers minimize)
    
    # Step 4: Value loss
    # Train the value network to predict returns better
    # Standard MSE loss
    value_loss = ((current_values - rollout.returns) ** 2).mean()
    
    # Step 5: KL penalty
    # Extra regularization: penalize drifting too far from reference model
    # This keeps us from forgetting the original language model's knowledge
    kl_penalty = (current_logprobs - rollout.ref_logprobs).mean()
    
    # Step 6: Combine everything
    total_loss = (
        policy_loss +                        # Main objective: improve policy
        config['vf_coef'] * value_loss +     # Train value network (vf_coef ≈ 0.5)
        config['kl_coef'] * kl_penalty       # Don't drift too far (kl_coef ≈ 0.1)
    )
    
    # Step 7: Standard PyTorch training step
    optimizer.zero_grad()
    total_loss.backward()
    
    # Clip gradients to prevent exploding gradients
    torch.nn.utils.clip_grad_norm_(
        policy_model.parameters(), 
        config['max_grad_norm']
    )
    
    optimizer.step()
    
    return {
        'policy_loss': policy_loss.item(),
        'value_loss': value_loss.item(),
        'kl': kl_penalty.item()
    }

# Show the structure
print("PPO Update in 7 steps:")
print()
print("  1. Re-compute log probs & values (policy changed since rollout)")
print("  2. Compute ratio of new/old policy probabilities")
print("  3. Compute clipped objective (the PPO magic)")
print("  4. Compute value loss (train value network)")
print("  5. Compute KL penalty (don't forget original LM)")
print("  6. Combine losses and backprop")
print("  7. Clip gradients and step optimizer")
print()
print("The clipping is what makes PPO 'proximal' — we update carefully.")
PPO Update in 7 steps:

  1. Re-compute log probs & values (policy changed since rollout)
  2. Compute ratio of new/old policy probabilities
  3. Compute clipped objective (the PPO magic)
  4. Compute value loss (train value network)
  5. Compute KL penalty (don't forget original LM)
  6. Combine losses and backprop
  7. Clip gradients and step optimizer

The clipping is what makes PPO 'proximal' — we update carefully.
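
To see what the clipping buys you, here’s a tiny standalone demo with made-up numbers: for an action with a positive advantage, the clipped objective flattens out once the ratio passes 1 + clip_ratio, so the policy gets no extra credit for drifting even further from the old policy in a single update.

import torch

# Illustrative demo: the clipped surrogate for one action with advantage +1.
clip_ratio = 0.2
advantage = torch.tensor(1.0)
ratios = torch.tensor([0.5, 0.8, 1.0, 1.2, 1.5, 2.0])   # new_prob / old_prob

unclipped = ratios * advantage
clipped = torch.clamp(ratios, 1 - clip_ratio, 1 + clip_ratio) * advantage
objective = torch.min(unclipped, clipped)   # what PPO maximizes (before the minus sign)

for r, o in zip(ratios.tolist(), objective.tolist()):
    print(f"ratio {r:.1f} → objective {o:.2f}")

Above 1.2 the objective is flat, so its gradient is zero and that sample simply stops pushing the policy. For negative advantages the flattening happens below 1 - clip_ratio instead: once the probability of a bad action has already been pushed down enough, that sample stops pushing further.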

The Complete Training Loop

Alright, let’s zoom out and see the full picture.

Here’s the rhythm of RLHF training, iteration by iteration:

def train_rlhf_loop(config):
    """
    The complete RLHF training loop.
    
    This is pseudocode showing the structure. In practice, you'd fill in
    the actual model calls.
    """
    
    for iteration in range(config['num_iterations']):
        
        # ===== PHASE 1: ROLLOUT =====
        # Generate a batch of responses and collect all the data we need
        
        # Sample some prompts from your dataset
        # prompts = sample_prompts(config['batch_size'])
        
        # Generate responses with the current policy
        # responses = policy_model.generate(prompts)
        
        # Score them with the reward model
        # rewards = reward_model(prompts, responses)
        
        # Get value predictions
        # values = value_network(prompts, responses)
        
        # Get log probabilities from policy and reference model
        # policy_logprobs = policy_model.get_logprobs(prompts, responses)
        # ref_logprobs = reference_model.get_logprobs(prompts, responses)
        
        # Compute advantages using GAE
        # advantages, returns = compute_gae(rewards, values, gamma=0.99, lam=0.95)
        
        # Normalize advantages
        # advantages = whiten_advantages(advantages)
        
        # Package everything into a rollout batch
        # rollout = RolloutBatch(
        #     query_tensors=prompts,
        #     response_tensors=responses,
        #     logprobs=policy_logprobs,
        #     ref_logprobs=ref_logprobs,
        #     values=values,
        #     rewards=rewards,
        #     advantages=advantages,
        #     returns=returns
        # )
        
        # ===== PHASE 2: PPO UPDATE =====
        # Learn from the rollout, multiple passes over the same data
        
        # for epoch in range(config['ppo_epochs']):
        #     metrics = ppo_update_step(
        #         policy_model,
        #         value_network,
        #         rollout,
        #         optimizer,
        #         config
        #     )
        
        # ===== LOGGING & CHECKPOINTING =====
        # Track progress, save models periodically
        
        # if iteration % 10 == 0:
        #     log_metrics(metrics)
        #     print(f"Iteration {iteration}: KL={metrics['kl']:.4f}, "
        #           f"Policy Loss={metrics['policy_loss']:.4f}")
        
        # if iteration % 100 == 0:
        #     save_checkpoint(policy_model, value_network, iteration)
        
        pass
    
    # return policy_model

# Show the structure
print("═" * 60)
print("RLHF Training Loop — The Big Picture")
print("═" * 60)
print()
print("For each iteration:")
print()
print("  PHASE 1: ROLLOUT (Generate & Score)")
print("    ├─ Sample prompts from dataset")
print("    ├─ Generate responses with current policy")
print("    ├─ Score with reward model")
print("    ├─ Get value predictions")
print("    ├─ Compute advantages (GAE)")
print("    └─ Package into RolloutBatch")
print()
print("  PHASE 2: UPDATE (Learn)")
print("    └─ Run PPO updates (4 epochs typically)")
print("        ├─ Re-compute log probs with current policy")
print("        ├─ Compute clipped policy loss")
print("        ├─ Compute value loss")
print("        ├─ Add KL penalty")
print("        └─ Backprop and step")
print()
print("  Repeat until convergence (or you run out of patience/budget)")
print()
print("═" * 60)

# Run with example config
config = {
    'num_iterations': 1000,
    'ppo_epochs': 4,
    'batch_size': 8,
    'clip_ratio': 0.2,
    'vf_coef': 0.5,
    'kl_coef': 0.1,
    'max_grad_norm': 1.0
}

train_rlhf_loop(config)
════════════════════════════════════════════════════════════
RLHF Training Loop — The Big Picture
════════════════════════════════════════════════════════════

For each iteration:

  PHASE 1: ROLLOUT (Generate & Score)
    ├─ Sample prompts from dataset
    ├─ Generate responses with current policy
    ├─ Score with reward model
    ├─ Get value predictions
    ├─ Compute advantages (GAE)
    └─ Package into RolloutBatch

  PHASE 2: UPDATE (Learn)
    └─ Run PPO updates (4 epochs typically)
        ├─ Re-compute log probs with current policy
        ├─ Compute clipped policy loss
        ├─ Compute value loss
        ├─ Add KL penalty
        └─ Backprop and step

  Repeat until convergence (or you run out of patience/budget)

════════════════════════════════════════════════════════════

Key Hyperparameters (And What They Actually Do)

Here are the knobs you can turn, and what happens when you turn them:

| Parameter | Typical Value | What It Does | When To Change |
| --- | --- | --- | --- |
| gamma | 0.99 | Discount factor — how much you care about future rewards vs immediate ones. Higher = “patient”, lower = “myopic”. | Keep at 0.99 for language. You care about the whole response, not just the next token. |
| gae_lambda | 0.95 | GAE smoothing — bias-variance tradeoff. Higher = trust multi-step returns (less bias, more variance). Lower = trust single-step (more bias, less variance). | 0.95 is the sweet spot. Don’t mess with it unless you’re doing research. |
| ppo_epochs | 4 | How many times to learn from each rollout. Higher = squeeze more from limited data, but risk overfitting. | 4 is standard. Increase if data is expensive (it is). Decrease if you see KL divergence exploding. |
| clip_ratio | 0.2 | How far the policy can change in one update. The “proximal” in PPO. Higher = bigger steps (faster, riskier). Lower = smaller steps (slower, safer). | 0.2 is battle-tested. Don’t change unless you know what you’re doing. |
| batch_size | 4-16 | Prompts per iteration. Higher = more stable, slower. Lower = faster, noisier. | Limited by GPU memory. RLHF is memory-hungry (you’re running 4 models at once). |
| vf_coef | 0.5 | Value loss weight — how much you care about training the value network vs the policy. | 0.5 means “value loss matters half as much as policy loss.” Usually fine. |
| kl_coef | 0.01-0.1 | KL penalty weight — how much you penalize drifting from the reference model. Higher = stay closer to the original LM (safer but less freedom). | Start low (0.01), increase if you see mode collapse or gibberish. |
| max_grad_norm | 1.0 | Gradient clipping threshold. Prevents exploding gradients. | 1.0 works. Lower if training is unstable. |

The defaults are pretty good. RLHF has been tuned by a lot of very smart people spending a lot of GPU-hours. Trust the defaults until you have a specific reason not to.
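
One practical way to act on the KL advice above (a sketch of a common pattern, not part of the loop shown earlier): watch the KL that ppo_update_step reports and cut the inner epochs short when it grows too large. The target_kl threshold here is an illustrative placeholder, not a prescribed value.

# Sketch: stop the inner PPO epochs early if the policy drifts too far.
# `target_kl` is an illustrative threshold, not a prescribed value.
def run_ppo_epochs(policy_model, value_network, rollout, optimizer, config,
                   target_kl: float = 0.05):
    metrics = {}
    for epoch in range(config['ppo_epochs']):
        metrics = ppo_update_step(policy_model, value_network, rollout,
                                  optimizer, config)
        if abs(metrics['kl']) > target_kl:
            break   # drifted enough on this rollout; go collect fresh data
    return metrics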

Wrapping Up

You now understand the heartbeat of RLHF:

  1. Generate responses (rollout)

  2. Evaluate them (rewards)

  3. Compute advantages (GAE) — “better or worse than expected?”

  4. Update the policy (PPO) — carefully, with clipping

That’s the loop. Do it a few thousand times, and you’ve got yourself a fine-tuned language model that optimizes for whatever your reward model says is good.

Pretty wild, right? You’re literally teaching the model through trial and error, just like how you’d train a dog. Except the dog is a billion-parameter transformer and the treats are scalar reward values.

Next up: we’ll talk about reference models — how to keep one frozen copy of your policy to prevent the training from going off the rails. (It’s simpler than you might expect.)