We’ve trained a reward model. We’ve got our policy ready to learn.
But there’s one more piece to this RLHF puzzle. And it’s weirdly important.
We need to keep a copy of our model exactly as it is right now — before we start letting the reward model push it around. A frozen snapshot. An anchor point.
This is the reference model.
(And yes, it does mean we’ll have two copies of the same giant neural network sitting in memory. We’ll deal with that headache in a minute.)
The Drift Problem (Or: Why Models Go Crazy)¶
Here’s what happens without a reference model:
You start RLHF training. Your policy generates some text. The reward model scores it. The policy adjusts to get higher scores.
Sounds great, right?
Except... the policy starts to drift.
At first it’s subtle. Maybe it discovers that the reward model really likes certain phrases. So it uses them more often. Then it discovers that repeating the same word over and over scores well (because the reward model has some weird quirk). So it does that. A lot.
Eventually, you end up with a model that outputs complete gibberish but somehow scores incredibly high rewards.
This is called reward hacking. The policy has learned to exploit the reward model instead of actually getting better at the task.
Think of it like this: You’re trying to teach a dog to fetch. But instead of learning to bring back the ball, the dog realizes that if it just sits on your lap and licks your face, you’ll give it treats anyway. Technically high reward. Completely missing the point.
Enter the Reference Model¶
The reference model is our solution. It’s a frozen copy of the policy before RLHF training starts.
During training, we don’t just maximize reward. We maximize reward while staying close to the reference model.
Here’s the actual objective function we optimize:
reward(response) - β × KL(policy || reference)
Let me break that down in English:
reward(response): How much the reward model likes our output
β (beta): A hyperparameter controlling how much we care about staying close (typically 0.01 to 0.1)
KL(policy || reference): The KL divergence — a measure of how different the policy’s probability distribution is from the reference’s
The KL divergence is the key. It penalizes the policy for assigning very different probabilities to words compared to the reference. If the policy starts drifting into weird territory, the KL penalty pulls it back.
It’s like training that dog on a leash. Sure, explore a bit. Get creative. But don’t wander so far that you forget what a normal dog is supposed to do.
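To make that objective concrete, here’s a minimal sketch of how the penalized reward might be computed for a single response, assuming we already have per-token log probabilities from both models (the function name, shapes, and the beta value are illustrative, not from any library):
import torch

def kl_penalized_reward(reward, policy_logprobs, ref_logprobs, beta=0.05):
    """
    Illustrative sketch: combine the reward model's score with a KL penalty.

    reward:          scalar score the reward model gave this response
    policy_logprobs: per-token log probs the policy assigned to the response tokens
    ref_logprobs:    per-token log probs the reference assigned to the same tokens
    beta:            how hard we pull the policy back toward the reference
    """
    # Summed per-token log ratio is a simple estimate of the KL divergence
    # (the same trick we implement properly later in this notebook)
    kl_estimate = torch.sum(policy_logprobs - ref_logprobs)
    return reward - beta * kl_estimate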
SFT Model (our starting point)
│
├──> Policy Model (trainable, learns from rewards)
│
└──> Reference Model (frozen, stays exactly as it was)
What Can Go Wrong Without a Reference?¶
Let’s be specific about the failure modes:
1. Reward Hacking The policy finds shortcuts to high reward that completely miss the point. Like discovering the reward model gives high scores to responses that end with “I hope this helps!” — so the policy starts ending every response with that phrase, even when it makes no sense.
2. Mode Collapse The policy collapses to generating the same response over and over. Why? Because it found one response that scores well, and without the KL penalty, there’s no incentive to explore. You end up with a model that outputs “That’s a great question! I’m happy to help.” to literally everything.
3. Language Degradation The policy forgets how to speak coherently. It drifts so far from natural language that it’s generating statistically high-reward token sequences that look like alphabet soup to humans.
The reference model prevents all three. It keeps the policy grounded in the language patterns it learned during pretraining and SFT.
Pretty important, yeah?
import torch
import copy
from transformers import AutoModelForCausalLM
def create_reference_model(policy_model):
"""
Create a frozen reference model from the policy.
This is literally just a deep copy with all the parameters frozen.
We're making a snapshot of the model at this exact moment in time.
"""
# Deep copy the entire model (this duplicates all weights in memory)
reference_model = copy.deepcopy(policy_model)
# Freeze every single parameter - no gradients, no updates
for param in reference_model.parameters():
param.requires_grad = False
# Set to evaluation mode (disables dropout, etc.)
reference_model.eval()
return reference_model
# Let's create a reference model from GPT-2
print("Loading GPT-2 as our policy model...")
policy_model = AutoModelForCausalLM.from_pretrained("gpt2")
print("Creating frozen reference model...")
reference_model = create_reference_model(policy_model)
# Verify that the policy can be trained but the reference cannot
policy_trainable = sum(p.numel() for p in policy_model.parameters() if p.requires_grad)
ref_trainable = sum(p.numel() for p in reference_model.parameters() if p.requires_grad)
print(f"\nPolicy model: {policy_trainable:,} trainable parameters")
print(f"Reference model: {ref_trainable:,} trainable parameters")
print(f"\nThe reference is completely frozen - exactly what we want!")Loading GPT-2 as our policy model...
Creating frozen reference model...
Policy model: 124,439,808 trainable parameters
Reference model: 0 trainable parameters
The reference is completely frozen - exactly what we want!
The Memory Problem (Sigh.)¶
Okay, so we just made a complete copy of our model.
If you’re training GPT-2 (124M parameters), that’s annoying but manageable. If you’re training Llama 2 7B... well, now you’ve got 14B parameters sitting in memory. If you’re training Llama 2 70B, you might want to start crying now.
Two identical massive models. Double the memory usage. All so we can compute a penalty term.
There are a few ways to deal with this headache:
Option 1: Half Precision Reference Keep the reference in FP16 (half precision) instead of FP32. Cuts memory in half. The reference doesn’t need to be super precise — we’re just using it to compute KL divergence.
Option 2: CPU Reference Move the reference to CPU. Frees up GPU memory for the policy. But now every time you need to compute KL, you’re shuffling data between CPU and GPU. Slower, but sometimes necessary.
Option 3: Periodic KL Computation Only compute the KL penalty every N steps instead of every step. It’s an approximation, but it can work if you’re willing to accept slightly less stable training (there’s a quick sketch of this after the memory comparison below).
Option 4: A Different Algorithm (DPO) Sidestep the problem with a different training objective. DPO (Direct Preference Optimization) drops the separate reward model and the explicit RL loop entirely; it does still use a reference model inside its loss, but the overall setup is much simpler. We’ll look at that in the next notebook.
Let’s implement options 1 and 2:
# Option 1: Keep reference in half precision
def create_reference_model_fp16(policy_model):
"""
Create reference model in half precision to save memory.
FP16 (float16) uses 16 bits per parameter instead of 32.
Cuts memory usage in half. The reference doesn't need full precision
since we're just using it for KL divergence calculations.
"""
reference_model = copy.deepcopy(policy_model)
reference_model = reference_model.half() # Convert to FP16
for param in reference_model.parameters():
param.requires_grad = False
reference_model.eval()
return reference_model
# Option 2: Move reference to CPU (slower but saves GPU memory)
def create_reference_model_cpu(policy_model):
"""
Create reference model on CPU to save GPU memory.
This moves the reference entirely to CPU. Frees up GPU memory
for the policy, but you'll pay a speed penalty when computing
KL divergence (have to move data between CPU and GPU).
"""
reference_model = copy.deepcopy(policy_model)
reference_model = reference_model.cpu()
for param in reference_model.parameters():
param.requires_grad = False
reference_model.eval()
return reference_model
# Let's try the FP16 version
print("Creating FP16 reference model...")
reference_model_fp16 = create_reference_model_fp16(policy_model)
# Check memory savings (approximate)
fp32_params = sum(p.numel() for p in reference_model.parameters())
fp16_params = sum(p.numel() for p in reference_model_fp16.parameters())
fp32_memory_mb = (fp32_params * 4) / (1024 * 1024) # 4 bytes per FP32
fp16_memory_mb = (fp16_params * 2) / (1024 * 1024) # 2 bytes per FP16
print(f"\nMemory usage comparison:")
print(f" FP32 reference: ~{fp32_memory_mb:.1f} MB")
print(f" FP16 reference: ~{fp16_memory_mb:.1f} MB")
print(f" Savings: ~{fp32_memory_mb - fp16_memory_mb:.1f} MB ({((fp32_memory_mb - fp16_memory_mb) / fp32_memory_mb * 100):.0f}% reduction)")
print(f"\nFor a 7B parameter model, that's saving ~14GB of memory. Not bad!")Creating FP16 reference model...
Memory usage comparison:
FP32 reference: ~474.7 MB
FP16 reference: ~237.4 MB
Savings: ~237.4 MB (50% reduction)
For a 7B parameter model, that's saving ~14GB of memory. Not bad!
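For completeness, Option 3 needs no new model setup at all; it’s just a tweak to the training loop. A rough sketch (it reuses the get_log_probs helper we’ll write later in this notebook, and the step/batch names are hypothetical):
KL_EVERY_N_STEPS = 4  # hypothetical setting: how often to pay for the reference forward pass

def kl_penalty_for_step(step, batch, policy_logprobs, reference_model, beta=0.05):
    """
    Option 3 sketch: only run the frozen reference (and apply the KL penalty)
    every N steps. On skipped steps the policy optimizes raw reward, which is
    why this is an approximation and can be a bit less stable.
    """
    if step % KL_EVERY_N_STEPS != 0:
        return 0.0  # skip the penalty and the reference forward pass this step
    ref_logprobs = get_log_probs(reference_model, batch["input_ids"], batch["attention_mask"])
    return beta * (policy_logprobs - ref_logprobs).sum()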
Computing KL Divergence: The Actual Math¶
Alright, time to get into the weeds a bit.
We need to compute the KL divergence between the policy and reference. But what does that actually mean in code?
Remember: both models are probability distributions over the next token. Given the same input, they both output a probability for every possible next token.
The KL divergence measures how different those probability distributions are.
The formula is:
KL(policy || reference) = Σ p(token) × log(p(token) / r(token))
In English:
p(token): Probability the policy assigns to this token
r(token): Probability the reference assigns to this token
log(p/r): The log ratio (becomes 0 when they’re equal, positive when policy prefers this token more, negative when less)
Σ: Sum over all tokens in the vocabulary
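Written out in code, that sum over the whole vocabulary at a single position would look something like this (purely illustrative, with made-up logits; it isn’t part of the pipeline we’re building):
import torch
import torch.nn.functional as F

# Illustrative: exact KL at one position, summed over GPT-2's full vocabulary
policy_logits = torch.randn(50257)  # pretend logits from the policy at this position
ref_logits = torch.randn(50257)     # pretend logits from the reference at the same position

p_log = F.log_softmax(policy_logits, dim=-1)  # log p(token) for every token in the vocab
r_log = F.log_softmax(ref_logits, dim=-1)     # log r(token) for every token in the vocab

exact_kl = (p_log.exp() * (p_log - r_log)).sum()  # Σ p(token) × log(p(token) / r(token))
print(f"Exact KL over the full vocabulary: {exact_kl.item():.4f}")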
The trick: we don’t actually compute this sum over the entire vocabulary (that would be 50,000+ tokens for GPT-2).
Instead, we compute it only for the tokens that were actually generated. This is called the per-token KL divergence and it’s what you’ll see in practice.
For each token in the generated response:
Get the policy’s log probability for that token
Get the reference’s log probability for that token
Take the difference
Let’s implement it:
import torch.nn.functional as F
def get_log_probs(model, input_ids, attention_mask):
"""
Get log probabilities for tokens under a model.
This runs the model forward and extracts the log probability
it assigned to each token that actually appeared in the sequence.
"""
with torch.no_grad(): # No gradients for the reference
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
logits = outputs.logits # Raw scores for each token in vocab
# Shift for next-token prediction
# The model predicts token i+1 from tokens 0...i
# So logits[:, 0] predicts token 1, logits[:, 1] predicts token 2, etc.
shift_logits = logits[:, :-1, :] # Drop last position (no next token)
shift_labels = input_ids[:, 1:] # Drop first token (not predicted)
# Convert logits to log probabilities via log softmax
log_probs = F.log_softmax(shift_logits, dim=-1)
# Extract just the log probs for the actual tokens that appeared
# This gathers log_probs[i, shift_labels[i]] for each position i
token_log_probs = torch.gather(
log_probs,
dim=-1,
index=shift_labels.unsqueeze(-1)
).squeeze(-1)
return token_log_probs
# Let's test this with a real sentence
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
text = "Hello, how are you doing today?"
inputs = tokenizer(text, return_tensors="pt")
print(f"Input text: '{text}'")
print(f"Tokens: {inputs['input_ids'].tolist()[0]}")
# Get log probs from both models
policy_logprobs = get_log_probs(policy_model, inputs['input_ids'], inputs['attention_mask'])
ref_logprobs = get_log_probs(reference_model, inputs['input_ids'], inputs['attention_mask'])
print(f"\nLog probability shapes: {policy_logprobs.shape}")
print(f"(One log prob for each of the {policy_logprobs.shape[1]} tokens we're predicting)")
# Compute per-token KL divergence
kl_per_token = policy_logprobs - ref_logprobs
kl_mean = kl_per_token.mean()
print(f"\nPer-token KL divergences: {kl_per_token[0].tolist()}")
print(f"Mean KL divergence: {kl_mean.item():.4f}")
print(f"\nSince we haven't trained the policy yet, it's identical to the reference.")
print(f"That's why KL divergence is basically zero. After training, this would be > 0.")Input text: 'Hello, how are you doing today?'
Tokens: [15496, 11, 703, 389, 345, 1804, 1909, 30]
Log probability shapes: torch.Size([1, 7])
(One log prob for each of the 7 tokens we're predicting)
Per-token KL divergences: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
Mean KL divergence: 0.0000
Since we haven't trained the policy yet, it's identical to the reference.
That's why KL divergence is basically zero. After training, this would be > 0.
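One caveat before moving on: get_log_probs wraps everything in torch.no_grad(), which is fine for this demo but not for an actual training step, where the KL penalty has to push gradients back through the policy. A training-time version might look roughly like this (a sketch under that assumption; the function name is ours, and this isn’t a full PPO update):
def compute_kl_penalty(policy_model, reference_model, input_ids, attention_mask, beta=0.05):
    """
    Sketch of a training-time KL penalty: gradients flow through the policy's
    log probs, while the frozen reference is evaluated without them.
    """
    # Policy pass WITH gradients (this is the part we're training)
    policy_outputs = policy_model(input_ids=input_ids, attention_mask=attention_mask)
    policy_log_probs = F.log_softmax(policy_outputs.logits[:, :-1, :], dim=-1)
    policy_token_logprobs = torch.gather(
        policy_log_probs, dim=-1, index=input_ids[:, 1:].unsqueeze(-1)
    ).squeeze(-1)

    # Reference pass without gradients (get_log_probs already runs under no_grad,
    # and the reference is frozen anyway)
    ref_token_logprobs = get_log_probs(reference_model, input_ids, attention_mask)

    # Per-token KL estimate, scaled by beta, ready to subtract from per-token rewards
    return beta * (policy_token_logprobs - ref_token_logprobs)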
Sanity Check: Is the Reference Actually Frozen?¶
Before we finish, let’s verify our reference model is actually frozen. You’d be surprised how easy it is to mess this up (and then spend hours debugging why your training is behaving weirdly).
We want to check two things:
No parameters require gradients — the reference should not update during training
Weights match the policy initially: they should be identical right now, since we just copied them (after RLHF training, the two will have drifted apart)
Let’s check:
def verify_reference_frozen(policy_model, reference_model):
"""
Verify that reference model is properly frozen.
Returns True if everything looks good, False if something's wrong.
"""
# Check 1: No parameters should require gradients
ref_requires_grad = any(p.requires_grad for p in reference_model.parameters())
# Check 2: Weights should be identical to policy initially
# (They'll diverge after policy training)
first_policy_param = next(policy_model.parameters())
first_ref_param = next(reference_model.parameters())
weights_equal = torch.allclose(first_policy_param, first_ref_param, atol=1e-6)
print("Reference Model Verification:")
print(f" Any parameter requires grad? {ref_requires_grad}")
print(f" ✓ Should be False (frozen)" if not ref_requires_grad else " ✗ PROBLEM: Should be False!")
print(f"\n Weights equal to policy? {weights_equal}")
print(f" ✓ Should be True initially (same starting point)" if weights_equal else " ✓ Different after training")
print(f"\n In eval mode? {not reference_model.training}")
print(f" ✓ Should be True (no dropout, etc.)" if not reference_model.training else " ✗ PROBLEM: Should be True!")
all_good = (not ref_requires_grad) and (not reference_model.training)
if all_good:
print("\n✓ Reference model is properly frozen!")
else:
print("\n✗ Something's wrong with the reference model setup!")
return all_good
# Test it
verify_reference_frozen(policy_model, reference_model)
Reference Model Verification:
Any parameter requires grad? False
✓ Should be False (frozen)
Weights equal to policy? True
✓ Should be True initially (same starting point)
In eval mode? True
✓ Should be True (no dropout, etc.)
✓ Reference model is properly frozen!
True
Monitoring Drift During Training¶
One more useful tool: tracking how far the policy has drifted from the reference.
During RLHF training, you want to keep an eye on this. If the policy drifts too far, you might need to increase β (the KL penalty weight). If it barely drifts at all, you might be able to decrease β and let it explore more.
Here’s a simple way to measure drift in weight space:
def compute_weight_divergence(policy_model, reference_model):
"""
Compute how far policy weights have diverged from reference.
This gives you a single number you can track during training.
Useful for detecting if the policy is drifting too far too fast.
"""
total_diff = 0.0
total_norm = 0.0
# Sum up the norm of differences across all parameters
for (name, p_param), (_, r_param) in zip(
policy_model.named_parameters(),
reference_model.named_parameters()
):
diff = (p_param - r_param).norm().item() # L2 norm of difference
norm = r_param.norm().item() # L2 norm of reference
total_diff += diff
total_norm += norm
# Relative divergence normalizes by reference magnitude
relative_divergence = total_diff / (total_norm + 1e-8)
return {
'absolute_divergence': total_diff,
'relative_divergence': relative_divergence
}
# Check initial divergence (should be ~0 since we haven't trained yet)
divergence = compute_weight_divergence(policy_model, reference_model)
print(f"Weight divergence from reference:")
print(f" Absolute: {divergence['absolute_divergence']:.6f}")
print(f" Relative: {divergence['relative_divergence']:.6f}")
print(f"\nShould be zero initially. After training, this will increase.")
print(f"If it increases too much, you're drifting far from the reference!")
print(f"\nTypical values during RLHF:")
print(f" Early training: 0.01 - 0.05")
print(f" Mid training: 0.05 - 0.15")
print(f" Late training: 0.15 - 0.30")
print(f" Too much drift: > 0.50 (might be reward hacking!)")Weight divergence from reference:
Absolute: 0.000000
Relative: 0.000000
Should be zero initially. After training, this will increase.
If it increases too much, you're drifting far from the reference!
Typical values during RLHF:
Early training: 0.01 - 0.05
Mid training: 0.05 - 0.15
Late training: 0.15 - 0.30
Too much drift: > 0.50 (might be reward hacking!)
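If you’d rather not hand-tune β, one option is a simple adaptive controller: raise β when the drift signal you’re tracking (usually the per-token KL) overshoots a target, and lower it when it undershoots. A rough sketch; the target and multiplier here are arbitrary placeholders, not tuned values:
def update_beta(beta, observed_kl, target_kl=6.0, factor=1.5):
    """
    Sketch of an adaptive KL coefficient: nudge beta up when the policy has
    drifted more than we want, and down when it is barely moving.
    """
    if observed_kl > target_kl * 1.5:
        beta *= factor  # drifting too far: penalize harder
    elif observed_kl < target_kl / 1.5:
        beta /= factor  # barely drifting: loosen the leash a little
    return beta

# Hypothetical usage inside a training loop:
# beta = update_beta(beta, kl_per_token.sum().item())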
Wrapping Up: The Complete Picture¶
Let’s recap what we’ve learned about reference models:
What it is: A frozen copy of your policy model at the start of RLHF training.
Why you need it: Without it, the policy drifts into reward hacking, mode collapse, or language degradation. The KL penalty keeps it grounded.
The cost: Double the memory usage (though we can mitigate with FP16 or CPU offloading).
The math: We compute KL divergence as the difference in log probabilities between policy and reference, then penalize the policy for drifting too far.
How to use it: Create it once at the start of training, freeze it completely, and use it to compute the KL penalty at every training step.
Now we’ve covered the complete RLHF pipeline:
✓ Supervised fine-tuning (SFT) to teach the model to follow instructions
✓ Reward model training to learn human preferences
✓ Reference model creation to prevent drift
✓ (Next up: putting it all together with PPO)
But there’s actually a simpler way to get much of this. It’s called Direct Preference Optimization (DPO), and it skips the reward model and the explicit RL loop entirely (it still leans on a reference model, just in a much lighter-weight way).
Let’s look at that next.