You’ve seen the pieces — policy models, reward models, value networks, PPO. Now let’s watch them dance together.
This is where RLHF gets real. We’re going to break down exactly what happens during training, step by step, no hand-waving. By the end, you’ll understand the heartbeat of RLHF: generate, evaluate, improve, repeat.
The Two-Phase Dance¶
RLHF training has a rhythm. Each iteration does two things:
Phase 1: Rollout — Your policy generates responses. You score them. You remember what happened.
Phase 2: Update — You use that data to make your policy better. Then you do it again.
Think of it like learning to cook. First you make a dish (rollout). Then you taste it, figure out what worked and what didn’t, and adjust your technique (update). Rinse and repeat until you’re a chef.
The key insight? You don’t learn while you’re generating. You generate first, then learn from what you generated. This separation is what makes PPO stable.
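Here is that separation as a minimal sketch. The rollout_fn and update_fn callables are stand-ins for the real rollout and PPO-update code we build later in this section:
import torch

def training_iteration(policy, rollout_fn, update_fn, ppo_epochs=4):
    # Phase 1: generate and score. No learning happens here, so no gradients needed.
    with torch.no_grad():
        rollout = rollout_fn(policy)

    # Phase 2: learn from the frozen rollout data, with several passes over it.
    for _ in range(ppo_epochs):
        metrics = update_fn(policy, rollout)
    return metrics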
What’s a Rollout, Anyway?¶
Here’s a word you’ll hear constantly in RL: rollout.
A rollout is just... running your policy and recording what happens. That’s it.
In RLHF terms:
You give your model some prompts
It generates responses (one token at a time)
You score those responses with your reward model
You ask your value network what it thinks
You save everything to memory
Think of it like recording a practice session. You’re not evaluating yourself mid-sentence — you finish your full response, then you look at what you did. The recording (the rollout) lets you study your performance afterward.
In each rollout we store:
import torch
from dataclasses import dataclass
from typing import List, Optional
@dataclass
class RolloutBatch:
"""Everything we need to remember from generating responses."""
# What we started with
query_tensors: torch.Tensor # The prompts we gave the model
response_tensors: torch.Tensor # What the model generated
# What the models thought at generation time
logprobs: torch.Tensor # How confident was our policy?
ref_logprobs: torch.Tensor # How confident was the reference model?
values: torch.Tensor # What did the value network predict?
# What we computed afterward
rewards: torch.Tensor # Reward model scores (the "grades")
advantages: torch.Tensor # "How much better than expected?" (we'll explain this!)
returns: torch.Tensor # Target values for training the value network
# Let's visualize what a rollout looks like
print("A RolloutBatch is like a student's exam:")
print()
print(" query_tensors → The questions")
print(" response_tensors → The student's answers")
print(" logprobs → Student's confidence in each answer")
print(" ref_logprobs → What a baseline student would've done")
print(" values → Student's prediction of their grade")
print(" rewards → The actual grade from the teacher")
print(" advantages → Did they do better or worse than expected?")
print(" returns → What the grade *should've* been")
print()
print("We store all of this so we can learn from it later.")A RolloutBatch is like a student's exam:
query_tensors → The questions
response_tensors → The student's answers
logprobs → Student's confidence in each answer
ref_logprobs → What a baseline student would've done
values → Student's prediction of their grade
rewards → The actual grade from the teacher
advantages → Did they do better or worse than expected?
returns → What the grade *should've* been
We store all of this so we can learn from it later.
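To pin the shapes down, here’s a toy RolloutBatch filled with dummy tensors. The sizes are made up purely for illustration (2 prompts, 6 prompt tokens, 10 response tokens); real batches are just bigger:
# A toy RolloutBatch with dummy data, only to make the shapes concrete.
batch, prompt_len, resp_len = 2, 6, 10
toy = RolloutBatch(
    query_tensors=torch.randint(0, 50_000, (batch, prompt_len)),   # prompt token ids
    response_tensors=torch.randint(0, 50_000, (batch, resp_len)),  # generated token ids
    logprobs=torch.randn(batch, resp_len),      # one log prob per generated token
    ref_logprobs=torch.randn(batch, resp_len),  # same shape, from the frozen reference model
    values=torch.randn(batch, resp_len),        # one value estimate per response position
    rewards=torch.randn(batch, resp_len),       # per-token rewards (often zero except at the end)
    advantages=torch.zeros(batch, resp_len),    # filled in by GAE below
    returns=torch.zeros(batch, resp_len),       # filled in by GAE below
)
print(toy.logprobs.shape)  # torch.Size([2, 10]): one entry per response token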
Advantages: The Key Insight¶
Now we get to the really clever bit. Let’s talk about advantages.
Here’s the core question in reinforcement learning: which actions were actually good?
You might think “just use the rewards!” But raw rewards are too blunt. Imagine you’re playing basketball and your team wins 100-98. Every action you took gets associated with “winning”, yet some of your shots were terrible and you just got lucky.
What you really want to know is: which actions were better than expected?
That’s an advantage. It’s not the absolute reward. It’s the surprise.
If your value network predicted you’d get a reward of 5, and you got 7, your advantage is +2. You did better than expected! Reinforce that behavior.
If it predicted 8 and you got 7, your advantage is -1. You did worse than expected. Maybe don’t do that again.
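That arithmetic is the whole idea in miniature. Here it is on the numbers from the example above:
# The simplest possible advantage: what actually happened minus what was predicted.
predicted = torch.tensor([5.0, 8.0])  # the value network's guesses
actual = torch.tensor([7.0, 7.0])     # the rewards that actually came back
advantage = actual - predicted
print(advantage)  # tensor([ 2., -1.]): reinforce the first, discourage the second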
This is where GAE comes in — Generalized Advantage Estimation. (Fancy name, but the intuition is simple: it’s a smart way to compute these “better than expected” signals.)
The formula looks scary, but let’s break it down:
def compute_gae(
rewards: torch.Tensor,
values: torch.Tensor,
gamma: float = 0.99,
lam: float = 0.95
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Compute Generalized Advantage Estimation (GAE).
This answers: "How much better/worse was each action compared to what we expected?"
The math (don't panic):
A_t = Σ (γλ)^l * δ_{t+l}
Where δ_t is the "TD error" (Temporal Difference error):
δ_t = r_t + γV(s_{t+1}) - V(s_t)
In English:
δ_t = "actual reward" + "discounted future value" - "predicted value"
δ_t = "what happened" - "what we expected"
Then we sum up these errors over time, with exponential decay (γλ).
Args:
rewards: What actually happened, shape (batch, seq_len)
values: What we predicted would happen, shape (batch, seq_len)
gamma: Discount factor (how much we care about future rewards)
Higher = "future matters", lower = "only now matters"
lam: GAE lambda (bias-variance knob, 0.95 is standard)
Higher = trust multi-step returns, lower = trust single-step
Returns:
advantages: "Better or worse than expected?" for each timestep
returns: What the value *should've* been (for training value network)
"""
batch_size, seq_len = rewards.shape
advantages = torch.zeros_like(rewards)
last_gae = 0
# We go backwards through time (from end to beginning)
# Why? Because we need to know the future to compute "expected future value"
for t in reversed(range(seq_len)):
if t == seq_len - 1:
next_value = 0 # End of sequence, no future
else:
next_value = values[:, t + 1]
# TD error: "what we got" - "what we expected"
# r_t = immediate reward
# γ * V(s_{t+1}) = discounted predicted future
# V(s_t) = what we thought we'd get total
delta = rewards[:, t] + gamma * next_value - values[:, t]
# GAE accumulation: sum up errors with exponential decay
# This smooths out noise and balances bias vs variance
last_gae = delta + gamma * lam * last_gae
advantages[:, t] = last_gae
# Returns are what the value network should've predicted
# They're the "ground truth" for training it
returns = advantages + values
return advantages, returns
# Let's see it in action
batch_size, seq_len = 2, 10
# Imagine we got these rewards (somewhat random, centered near 0)
rewards = torch.randn(batch_size, seq_len) * 0.5
# And our value network predicted these values
values = torch.randn(batch_size, seq_len)
advantages, returns = compute_gae(rewards, values)
print("Example rollout:")
print(f" Rewards shape: {rewards.shape} — what actually happened")
print(f" Values shape: {values.shape} — what we predicted")
print()
print(f" Advantages shape: {advantages.shape} — surprise signal (positive = better than expected)")
print(f" Returns shape: {returns.shape} — target for value network training")
print()
print("Sample advantages (first sequence, first 5 timesteps):")
print(advantages[0, :5])
print()
print("Positive = 'do more of this', negative = 'do less of this'")Example rollout:
Rewards shape: torch.Size([2, 10]) — what actually happened
Values shape: torch.Size([2, 10]) — what we predicted
Advantages shape: torch.Size([2, 10]) — surprise signal (positive = better than expected)
Returns shape: torch.Size([2, 10]) — target for value network training
Sample advantages (first sequence, first 5 timesteps):
tensor([-1.8015, -3.0808, -1.3732, -3.9525, -1.7698])
Positive = 'do more of this', negative = 'do less of this'
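A quick sanity check on the lam knob, reusing the rewards and values from above: with lam=0, GAE collapses to the one-step TD error, and with lam=1 and gamma=1 it becomes “sum of remaining rewards minus the prediction”.
# lam = 0: advantages are exactly the one-step TD errors
adv_lam0, _ = compute_gae(rewards, values, gamma=0.99, lam=0.0)
next_values = torch.cat([values[:, 1:], torch.zeros(batch_size, 1)], dim=1)
td_errors = rewards + 0.99 * next_values - values
print(torch.allclose(adv_lam0, td_errors))  # True

# lam = 1, gamma = 1: advantages are "sum of remaining rewards minus the prediction"
adv_lam1, _ = compute_gae(rewards, values, gamma=1.0, lam=1.0)
remaining = torch.flip(torch.cumsum(torch.flip(rewards, dims=[1]), dim=1), dims=[1])
print(torch.allclose(adv_lam1, remaining - values, atol=1e-5))  # True (up to float rounding)
Everything in between (the defaults gamma=0.99 and lam=0.95) is a weighted blend of those two extremes.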
Whitening: Keeping Things Stable¶
One more trick before we update the policy: advantage normalization (also called “whitening”).
Raw advantages can have wildly different scales. One batch might have advantages from -2 to +2. Another might be -100 to +100. This makes training unstable — your gradients are all over the place.
Solution? Normalize them. Make them have mean 0 and standard deviation 1.
This is like grading on a curve. It doesn’t matter if the test was easy or hard — you’re comparing students to each other, not to an absolute scale.
def whiten_advantages(advantages: torch.Tensor) -> torch.Tensor:
"""
Normalize advantages to mean=0, std=1.
Why? Training stability. We want consistent gradient scales across batches.
This is "grading on a curve" — comparing actions to each other,
not to an absolute scale.
"""
mean = advantages.mean()
std = advantages.std() + 1e-8 # Small epsilon prevents division by zero
return (advantages - mean) / std
# Let's see the transformation
print("Before whitening:")
print(f" Mean: {advantages.mean():.4f}")
print(f" Std: {advantages.std():.4f}")
print(f" Min: {advantages.min():.4f}")
print(f" Max: {advantages.max():.4f}")
whitened = whiten_advantages(advantages)
print()
print("After whitening:")
print(f" Mean: {whitened.mean():.4f} ← always ~0")
print(f" Std: {whitened.std():.4f} ← always ~1")
print(f" Min: {whitened.min():.4f}")
print(f" Max: {whitened.max():.4f}")
print()
print("Now the scale is consistent across all batches!")Before whitening:
Mean: -1.2152
Std: 1.3548
Min: -3.9525
Max: 1.6824
After whitening:
Mean: 0.0000 ← always ~0
Std: 1.0000 ← always ~1
Min: -2.0205
Max: 2.1388
Now the scale is consistent across all batches!
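One thing whitening never changes is the ordering: whichever actions looked best before normalization still look best afterward, which is exactly the “grading on a curve” intuition.
# Whitening shifts and rescales, but never reorders the advantages.
print(torch.equal(
    advantages.flatten().argsort(),
    whitened.flatten().argsort()
))  # True: the ranking is identical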
Phase 2: The PPO Update¶
Okay! We’ve generated rollouts. We’ve computed advantages. We’ve whitened them.
Now comes the actual learning: the PPO update.
Remember, PPO stands for Proximal Policy Optimization. The key word is proximal — we want to update the policy, but not too much. Stay close to the old policy. (That’s what the “clipping” does, which we’ll see in the code.)
In each update step:
def ppo_update_step(
policy_model,
value_network,
rollout: RolloutBatch,
optimizer,
config
):
"""
One PPO update step. This is the heart of RLHF training.
We call this multiple times (ppo_epochs) for each rollout.
Why? Because we have limited data, so we squeeze as much learning
from it as we can. (But not too much — that's where clipping helps.)
"""
# Step 1: Re-compute outputs with CURRENT policy
# (The policy has changed since we did the rollout, so we need fresh numbers)
current_logprobs = policy_model.get_logprobs(
rollout.query_tensors,
rollout.response_tensors
)
current_values = value_network(
rollout.query_tensors,
rollout.response_tensors
)
# Step 2: Compute the "ratio"
# This tells us: how much more likely is the current policy to take
# the same action compared to the old policy?
#
# ratio = P_new(action) / P_old(action)
# = exp(log P_new - log P_old)
ratio = torch.exp(current_logprobs - rollout.logprobs)
# Step 3: PPO clipped objective (this is THE magic)
# We compute two versions and take the minimum
# Unclipped: standard policy gradient
# "How good was this action (advantage) * how much more likely we are to take it (ratio)"
unclipped = ratio * rollout.advantages
# Clipped: same thing, but we clip the ratio
# Don't let it get too far from 1.0 (which means "same probability as old policy")
# clip_ratio is typically 0.2, so we allow 0.8 to 1.2
clipped = torch.clamp(
ratio,
1 - config['clip_ratio'], # Lower bound: 0.8
1 + config['clip_ratio'] # Upper bound: 1.2
) * rollout.advantages
# Take the minimum — this is conservative, prevents big updates
# We want to improve, but carefully
policy_loss = -torch.min(unclipped, clipped).mean()
# (Negative because we want to maximize, but optimizers minimize)
# Step 4: Value loss
# Train the value network to predict returns better
# Standard MSE loss
value_loss = ((current_values - rollout.returns) ** 2).mean()
# Step 5: KL penalty
# Extra regularization: penalize drifting too far from reference model
# This keeps us from forgetting the original language model's knowledge
kl_penalty = (current_logprobs - rollout.ref_logprobs).mean()
# Step 6: Combine everything
total_loss = (
policy_loss + # Main objective: improve policy
config['vf_coef'] * value_loss + # Train value network (vf_coef ≈ 0.5)
config['kl_coef'] * kl_penalty # Don't drift too far (kl_coef ≈ 0.1)
)
# Step 7: Standard PyTorch training step
optimizer.zero_grad()
total_loss.backward()
# Clip gradients to prevent exploding gradients
torch.nn.utils.clip_grad_norm_(
policy_model.parameters(),
config['max_grad_norm']
)
optimizer.step()
return {
'policy_loss': policy_loss.item(),
'value_loss': value_loss.item(),
'kl': kl_penalty.item()
}
# Show the structure
print("PPO Update in 7 steps:")
print()
print(" 1. Re-compute log probs & values (policy changed since rollout)")
print(" 2. Compute ratio of new/old policy probabilities")
print(" 3. Compute clipped objective (the PPO magic)")
print(" 4. Compute value loss (train value network)")
print(" 5. Compute KL penalty (don't forget original LM)")
print(" 6. Combine losses and backprop")
print(" 7. Clip gradients and step optimizer")
print()
print("The clipping is what makes PPO 'proximal' — we update carefully.")PPO Update in 7 steps:
1. Re-compute log probs & values (policy changed since rollout)
2. Compute ratio of new/old policy probabilities
3. Compute clipped objective (the PPO magic)
4. Compute value loss (train value network)
5. Compute KL penalty (don't forget original LM)
6. Combine losses and backprop
7. Clip gradients and step optimizer
The clipping is what makes PPO 'proximal' — we update carefully.
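To see the clipping in action, here’s the clipped objective on a few hand-picked ratios, with a positive advantage and clip_ratio = 0.2. Notice how the objective goes flat once the ratio passes 1.2: beyond that point the gradient is zero, so a single update can’t push this action’s probability any further.
# What the clipped objective looks like for an action that was better than expected.
clip_ratio = 0.2
advantage = torch.tensor(1.0)                           # positive: "do more of this"
ratios = torch.tensor([0.5, 0.9, 1.0, 1.1, 1.3, 2.0])   # new policy prob / old policy prob
unclipped = ratios * advantage
clipped = torch.clamp(ratios, 1 - clip_ratio, 1 + clip_ratio) * advantage
objective = torch.min(unclipped, clipped)
print(objective)
# tensor([0.5000, 0.9000, 1.0000, 1.1000, 1.2000, 1.2000])
# Flat past 1.2: no reward for drifting further from the old policy in one step.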
The Complete Training Loop¶
Alright, let’s zoom out and see the full picture.
Here’s the rhythm of RLHF training, iteration by iteration:
def train_rlhf_loop(config):
"""
The complete RLHF training loop.
This is pseudocode showing the structure. In practice, you'd fill in
the actual model calls.
"""
for iteration in range(config['num_iterations']):
# ===== PHASE 1: ROLLOUT =====
# Generate a batch of responses and collect all the data we need
# Sample some prompts from your dataset
# prompts = sample_prompts(config['batch_size'])
# Generate responses with the current policy
# responses = policy_model.generate(prompts)
# Score them with the reward model
# rewards = reward_model(prompts, responses)
# Get value predictions
# values = value_network(prompts, responses)
# Get log probabilities from policy and reference model
# policy_logprobs = policy_model.get_logprobs(prompts, responses)
# ref_logprobs = reference_model.get_logprobs(prompts, responses)
# Compute advantages using GAE
# advantages, returns = compute_gae(rewards, values, gamma=0.99, lam=0.95)
# Normalize advantages
# advantages = whiten_advantages(advantages)
# Package everything into a rollout batch
# rollout = RolloutBatch(
# query_tensors=prompts,
# response_tensors=responses,
# logprobs=policy_logprobs,
# ref_logprobs=ref_logprobs,
# values=values,
# rewards=rewards,
# advantages=advantages,
# returns=returns
# )
# ===== PHASE 2: PPO UPDATE =====
# Learn from the rollout, multiple passes over the same data
# for epoch in range(config['ppo_epochs']):
# metrics = ppo_update_step(
# policy_model,
# value_network,
# rollout,
# optimizer,
# config
# )
# ===== LOGGING & CHECKPOINTING =====
# Track progress, save models periodically
# if iteration % 10 == 0:
# log_metrics(metrics)
# print(f"Iteration {iteration}: KL={metrics['kl']:.4f}, "
# f"Policy Loss={metrics['policy_loss']:.4f}")
# if iteration % 100 == 0:
# save_checkpoint(policy_model, value_network, iteration)
pass
# return policy_model
# Show the structure
print("═" * 60)
print("RLHF Training Loop — The Big Picture")
print("═" * 60)
print()
print("For each iteration:")
print()
print(" PHASE 1: ROLLOUT (Generate & Score)")
print(" ├─ Sample prompts from dataset")
print(" ├─ Generate responses with current policy")
print(" ├─ Score with reward model")
print(" ├─ Get value predictions")
print(" ├─ Compute advantages (GAE)")
print(" └─ Package into RolloutBatch")
print()
print(" PHASE 2: UPDATE (Learn)")
print(" └─ Run PPO updates (4 epochs typically)")
print(" ├─ Re-compute log probs with current policy")
print(" ├─ Compute clipped policy loss")
print(" ├─ Compute value loss")
print(" ├─ Add KL penalty")
print(" └─ Backprop and step")
print()
print(" Repeat until convergence (or you run out of patience/budget)")
print()
print("═" * 60)
# Run with example config
config = {
'num_iterations': 1000,
'ppo_epochs': 4,
'batch_size': 8,
'clip_ratio': 0.2,
'vf_coef': 0.5,
'kl_coef': 0.1,
'max_grad_norm': 1.0
}
train_rlhf_loop(config)
Key Hyperparameters (And What They Actually Do)¶
Here are the knobs you can turn, and what happens when you turn them:
| Parameter | Typical Value | What It Does | When To Change |
|---|---|---|---|
| `gamma` | 0.99 | Discount factor — how much you care about future rewards vs. immediate ones. Higher = “patient”, lower = “myopic”. | Keep at 0.99 for language. You care about the whole response, not just the next token. |
| `gae_lambda` | 0.95 | GAE smoothing — bias-variance tradeoff. Higher = trust multi-step returns (less bias, more variance). Lower = trust single-step (more bias, less variance). | 0.95 is the sweet spot. Don’t mess with it unless you’re doing research. |
| `ppo_epochs` | 4 | How many times to learn from each rollout. Higher = squeeze more from limited data, but risk overfitting. | 4 is standard. Increase if data is expensive (it is). Decrease if you see KL divergence exploding. |
| `clip_ratio` | 0.2 | How far the policy can change in one update. The “proximal” in PPO. Higher = bigger steps (faster, riskier). Lower = smaller steps (slower, safer). | 0.2 is battle-tested. Don’t change it unless you know what you’re doing. |
| `batch_size` | 4-16 | Prompts per iteration. Higher = more stable, slower. Lower = faster, noisier. | Limited by GPU memory. RLHF is memory-hungry (you’re running four models at once). |
| `vf_coef` | 0.5 | Value loss weight — how much you care about training the value network vs. the policy. | 0.5 means “value loss matters half as much as policy loss.” Usually fine. |
| `kl_coef` | 0.01-0.1 | KL penalty weight — how much you penalize drifting from the reference model. Higher = stay closer to the original LM (safer, but less freedom). | Start low (0.01), increase if you see mode collapse or gibberish. |
| `max_grad_norm` | 1.0 | Gradient clipping threshold. Prevents exploding gradients. | 1.0 works. Lower it if training is unstable. |
The defaults are pretty good. RLHF has been tuned by a lot of very smart people spending a lot of GPU-hours. Trust the defaults until you have a specific reason not to.
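For reference, here’s what those defaults look like collected into one config dict, in the same shape as the config used earlier. Treat it as a starting point, not gospel:
# A sensible starting configuration, mirroring the table above.
# kl_coef and batch_size are the knobs you're most likely to touch first.
default_config = {
    'num_iterations': 1000,
    'ppo_epochs': 4,       # passes over each rollout
    'batch_size': 8,       # prompts per iteration (GPU memory permitting)
    'gamma': 0.99,         # discount factor
    'gae_lambda': 0.95,    # GAE bias-variance knob
    'clip_ratio': 0.2,     # the "proximal" in PPO
    'vf_coef': 0.5,        # weight on the value loss
    'kl_coef': 0.05,       # KL penalty weight, within the 0.01-0.1 range above
    'max_grad_norm': 1.0,  # gradient clipping threshold
}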
Wrapping Up¶
You now understand the heartbeat of RLHF:
Generate responses (rollout)
Evaluate them (rewards)
Compute advantages (GAE) — “better or worse than expected?”
Update the policy (PPO) — carefully, with clipping
That’s the loop. Do it a few thousand times, and you’ve got yourself a fine-tuned language model that optimizes for whatever your reward model says is good.
Pretty wild, right? You’re literally teaching the model through trial and error, just like how you’d train a dog. Except the dog is a billion-parameter transformer and the treats are scalar reward values.
Next up: we’ll talk about reference models — how to keep one frozen copy of your policy to prevent the training from going off the rails. (It’s simpler than you might expect.)