We’ve built powerful reasoning systems: chain-of-thought, process reward models, MCTS, GRPO. But these techniques work best with large models (70B+). What about reasoning on a phone? In a browser? On a tiny GPU?
Distillation transfers the reasoning patterns from a large teacher to a small student.
The DeepSeek Discovery¶
From the DeepSeek-R1 paper:
The reasoning patterns of larger models can be distilled into smaller models, resulting in better performance compared to the reasoning patterns discovered through RL on small models.
In other words:
Train a 70B model to reason with RL → good reasoning emerges
Distill to a 7B model → 7B inherits the reasoning patterns
Result: 7B with distillation > 7B with direct RL
The small model can’t discover complex reasoning patterns on its own, but it can learn to imitate them.
Types of Distillation¶
1. Standard Knowledge Distillation¶
Train student to match teacher’s output distributions.
2. Reasoning Trace Distillation¶
Train student on teacher’s step-by-step solutions.
3. Behavioral Cloning¶
Just train student to produce the same final outputs.
For reasoning, trace distillation works best. The student learns not just what answer to give, but how to think.
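To make the distinction concrete, here is a small illustrative sketch (hypothetical strings, not output from the models used below) of what the student is trained to reproduce in each setup:
# Illustrative targets for one problem under each distillation flavor.
problem = "What is 15 + 28?"
# 1. Standard knowledge distillation: the student matches the teacher's full
#    probability distribution over the vocabulary at every token position
#    (soft targets are tensors of logits/probs; sketched here as a description).
kd_target = "teacher's per-token probability distributions"
# 2. Reasoning trace distillation: the teacher's entire step-by-step solution.
trace_target = (
    "Step 1: Break 28 into 20 + 8.\n"
    "Step 2: 15 + 20 = 35.\n"
    "Step 3: 35 + 8 = 43.\n"
    "Final answer: 43"
)
# 3. Behavioral cloning: only the final answer, no reasoning.
cloning_target = "43"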
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List
import numpy as np
# Load teacher and student models
# Teacher: larger model, Student: smaller model
print("Loading models...")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
teacher = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct", dtype="auto")
student = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", dtype="auto")
device = "cuda" if torch.cuda.is_available() else "cpu"
teacher = teacher.to(device)
student = student.to(device)
teacher.eval() # Teacher is frozen
# Student will be trained
print(f"Loaded on {device}")
print(f"Teacher: Qwen2.5-1.5B-Instruct")
print(f"Student: Qwen2.5-0.5B-Instruct")Loading models...
Loaded on cuda
Teacher: Qwen2.5-1.5B-Instruct
Student: Qwen2.5-0.5B-Instruct
Generating Teacher Reasoning Traces¶
First, we need to collect high-quality reasoning traces from the teacher model.
What’s a Trace?¶
A reasoning trace is the complete step-by-step solution a model generates when solving a problem. Not just the final answer—the whole internal monologue.
For example, when asked “What is 15 + 28?”, a trace might look like:
Step 1: I'll break this into parts: 15 = 10 + 5
Step 2: Add the tens: 10 + 20 = 30
Step 3: Add the ones: 5 + 8 = 13
Step 4: Combine: 30 + 13 = 43
The trace is everything from “Step 1” to “43”. It’s the how, not just the what.
In chain-of-thought prompting, we showed models a few example traces and asked them to produce their own. In distillation, we’re taking thousands of high-quality traces from a smart teacher model and training a smaller student to reproduce that same reasoning style.
Think of it like showing a student worked examples in a textbook. The trace is the worked example—every intermediate calculation, every logical step laid out explicitly.
def generate_teacher_traces(teacher, tokenizer, problems: List[str],
n_per_problem: int = 3,
max_tokens: int = 150) -> List[dict]:
"""
Generate reasoning traces from the teacher model.
We'll generate multiple traces per problem and filter for correctness.
"""
traces = []
for problem in problems:
prompt = f"Problem: {problem}\n\nSolution: Let me solve this step by step.\n"
inputs = tokenizer(prompt, return_tensors="pt").to(device)
for _ in range(n_per_problem):
with torch.no_grad():
outputs = teacher.generate(
**inputs,
max_new_tokens=max_tokens,
temperature=0.7,
do_sample=True,
pad_token_id=tokenizer.eos_token_id,
)
full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
trace = full_text[len(prompt):]
traces.append({
"problem": problem,
"prompt": prompt,
"trace": trace,
"full_text": prompt + trace
})
return traces
# Generate some traces
problems = [
"What is 15 + 28?",
"If a train travels 60 miles in 2 hours, what is its speed?",
"A store has 50 items. They sell 20% of them. How many are left?",
]
print("Generating teacher traces...")
teacher_traces = generate_teacher_traces(teacher, tokenizer, problems, n_per_problem=2)
print(f"\nGenerated {len(teacher_traces)} traces")
print("\nExample trace:")
print("="*60)
print(teacher_traces[0]["full_text"][:300] + "...")
Generating teacher traces...
Generated 6 traces
Example trace:
============================================================
Problem: What is 15 + 28?
Solution: Let me solve this step by step.
Step 1: Start with the first number, which is 15.
Step 2: Add 20 to it because 20 + 5 = 25
Step 3: Now add another 8 because 25 + 8 = 33
So, 15 + 28 equals 43.
Final answer: 15 + 28 = 43
Let's verify our solution:
15 + 28 = (1...
Trace Distillation Loss¶
The simplest form: train the student to produce the same reasoning traces as the teacher.
This is just supervised fine-tuning on teacher-generated data!
def compute_trace_distillation_loss(student, tokenizer,
trace: dict) -> torch.Tensor:
"""
Compute loss for matching a teacher trace.
This is just cross-entropy on the reasoning steps.
"""
full_text = trace["full_text"]
prompt_len = len(tokenizer(trace["prompt"])["input_ids"])
inputs = tokenizer(full_text, return_tensors="pt").to(device)
# Forward pass
outputs = student(**inputs, labels=inputs["input_ids"])
# We only care about loss on the reasoning trace, not the prompt
# In practice, we'd mask the prompt tokens
# For simplicity, we'll use the full loss here
return outputs.loss
# Test
loss = compute_trace_distillation_loss(student, tokenizer, teacher_traces[0])
print(f"Distillation loss: {loss.item():.4f}")Distillation loss: 0.9808
This loss (~0.98) is actually pretty good for a first attempt! Remember, this is cross-entropy: it measures how confident the student is about each next token. A loss of 0 would mean perfect certainty (impossible), while a loss around 1 means the student is reasonably confident but still learning.
For context: random guessing across Qwen2.5's roughly 152K-token vocabulary would give a loss around 12. So 0.98 means the student already has a decent prior; it's not flailing randomly. After training on many examples, we'd expect this to drop closer to 0.5 or lower.
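The code above notes that in practice we'd mask the prompt tokens so the student is only graded on the reasoning, not on reproducing the question. Here is a minimal sketch of that masking, assuming the same trace dict format that generate_teacher_traces produces; positions labeled -100 are ignored by the model's built-in cross-entropy loss.
def compute_masked_trace_loss(student, tokenizer, trace: dict) -> torch.Tensor:
    """Cross-entropy on the reasoning trace only, with prompt tokens masked out."""
    inputs = tokenizer(trace["full_text"], return_tensors="pt").to(device)
    # Approximate prompt length in tokens (tokenizing the prompt alone can
    # differ by a token at the boundary; close enough for a sketch).
    prompt_len = len(tokenizer(trace["prompt"])["input_ids"])
    # Copy the input ids and mask the prompt positions with -100,
    # which Hugging Face causal LMs skip when computing the loss.
    labels = inputs["input_ids"].clone()
    labels[:, :prompt_len] = -100
    outputs = student(**inputs, labels=labels)
    return outputs.loss

# Example usage (the value will differ from the unmasked loss above):
# masked_loss = compute_masked_trace_loss(student, tokenizer, teacher_traces[0])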
Token-Level Knowledge Distillation¶
A more sophisticated approach: match the teacher’s probability distribution at each token position.
$$\mathcal{L}_{\text{KD}} = \alpha \, T^{2} \, \mathrm{KL}\!\left(\mathrm{softmax}\!\left(\tfrac{z_{\text{teacher}}}{T}\right) \,\Big\|\, \mathrm{softmax}\!\left(\tfrac{z_{\text{student}}}{T}\right)\right) + (1 - \alpha)\, \mathcal{L}_{\text{CE}}$$
where $T$ is a temperature that softens the distributions, $z$ are the logits, and $\alpha$ weights the soft (teacher) targets against the hard (ground-truth) targets.
def knowledge_distillation_loss(teacher, student, tokenizer,
text: str, temperature: float = 2.0,
alpha: float = 0.5) -> torch.Tensor:
"""
Token-level knowledge distillation.
Combines:
1. KL divergence from teacher distributions
2. Hard target cross-entropy
Args:
teacher: Teacher model (frozen)
student: Student model (training)
tokenizer: Tokenizer
text: Text to distill on
temperature: Softening temperature
alpha: Weight on distillation vs. hard targets
Returns:
Combined loss
"""
inputs = tokenizer(text, return_tensors="pt").to(device)
# Get teacher logits
with torch.no_grad():
teacher_outputs = teacher(**inputs)
teacher_logits = teacher_outputs.logits
# Get student logits
student_outputs = student(**inputs)
student_logits = student_outputs.logits
# Soft targets (temperature-scaled)
teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
# KL divergence loss
kl_loss = F.kl_div(
student_log_probs[:, :-1, :], # Predict next token
teacher_probs[:, :-1, :],
reduction='batchmean'
)
# Hard target loss (standard cross-entropy)
hard_loss = F.cross_entropy(
student_logits[:, :-1, :].reshape(-1, student_logits.size(-1)),
inputs["input_ids"][:, 1:].reshape(-1)
)
# Combined loss
# Scale KL by T^2 (standard practice)
total_loss = alpha * (temperature ** 2) * kl_loss + (1 - alpha) * hard_loss
return total_loss, kl_loss, hard_loss
# Test
test_text = teacher_traces[0]["full_text"]
total, kl, hard = knowledge_distillation_loss(
teacher, student, tokenizer, test_text
)
print(f"Knowledge distillation losses:")
print(f" KL loss: {kl.item():.4f}")
print(f" Hard target loss: {hard.item():.4f}")
print(f" Total: {total.item():.4f}")Knowledge distillation losses:
KL loss: 57.7500
Hard target loss: 0.9805
Total: 116.0000
Wait, why is the total loss (116.0) so much bigger than either component?
Because of the scaling factor! When we use temperature $T = 2$, we multiply the KL loss by $T^2 = 4$. This is standard practice in distillation: it keeps the contribution of the soft targets balanced against the hard targets as the temperature changes.
Breaking it down:
KL loss (57.75): How different are the student's probabilities from the teacher's? This is naturally large because reduction='batchmean' divides the summed KL by the batch size, not by the number of token positions, so it is effectively summed over the whole sequence of full-vocabulary distributions.
Hard target loss (0.98): Standard cross-entropy against the actual next token. Same scale as the trace distillation loss above.
Total (116.0): $\alpha\,T^2$ times the KL plus $(1-\alpha)$ times the hard loss; the quick check below recombines the printed numbers.
The total will decrease as training progresses and the student learns to match the teacher's distribution better.
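As a quick sanity check, you can recombine the printed components yourself, using $\alpha = 0.5$ and $T = 2$ as in the code:
# Recombine the printed components: alpha * T^2 * KL + (1 - alpha) * hard
alpha, T = 0.5, 2.0
recombined = alpha * (T ** 2) * 57.75 + (1 - alpha) * 0.9805
print(f"Recombined total: {recombined:.2f}")  # ~116.0, matching the total above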
Training Loop for Distillation¶
def train_distillation_epoch(teacher, student, tokenizer,
traces: List[dict], optimizer,
temperature: float = 2.0,
alpha: float = 0.5) -> float:
"""
Train student for one epoch on teacher traces.
"""
student.train()
total_loss = 0.0
for trace in traces:
optimizer.zero_grad()
loss, _, _ = knowledge_distillation_loss(
teacher, student, tokenizer,
trace["full_text"],
temperature=temperature,
alpha=alpha
)
loss.backward()
optimizer.step()
total_loss += loss.item()
return total_loss / len(traces)
# Train for a few epochs
optimizer = torch.optim.AdamW(student.parameters(), lr=1e-5)
print("Training student with distillation...")
print("="*50)
for epoch in range(3):
loss = train_distillation_epoch(
teacher, student, tokenizer,
teacher_traces, optimizer
)
print(f"Epoch {epoch+1}: Loss = {loss:.4f}")
print("\nDistillation complete!")Training student with distillation...
==================================================
Epoch 1: Loss = 106.2500
Epoch 2: Loss = 83.0833
Epoch 3: Loss = 74.0417
Distillation complete!
These losses look huge (106 → 83 → 74) but they're actually fine! Remember, we're averaging the distillation loss across all traces, and that loss includes the $T^2$ scaling and sequence-summed KL we saw above.
What matters is the trend: loss is decreasing consistently. The student is learning to match the teacher's reasoning patterns better with each epoch. In a real training run with thousands of traces, you'd see this drop much further (into the 20s or lower).
Also note: we're using the knowledge distillation loss here (with the KL component), not the simpler trace distillation loss, so the scale is naturally larger. If we were just doing trace SFT, we'd see losses closer to the 0.98 we got in the first test.
Quality Filtering¶
Not all teacher traces are worth imitating. We should filter for:
Correct answers — Don’t teach wrong reasoning
Clear steps — Mumbled reasoning is hard to learn from
Diverse approaches — Multiple ways to solve problems
import re
def filter_traces(traces: List[dict], correct_answers: dict = None) -> List[dict]:
"""
Filter teacher traces for quality.
Args:
traces: List of trace dicts
correct_answers: Dict mapping problems to correct answers
Returns:
Filtered list of high-quality traces
"""
filtered = []
for trace in traces:
text = trace["trace"]
problem = trace["problem"]
# Filter 1: Must have step-by-step structure
has_steps = any(marker in text.lower()
for marker in ['step', 'first', 'then', 'next', 'finally'])
if not has_steps:
continue
# Filter 2: Must have reasonable length
if len(text.split()) < 20 or len(text.split()) > 300:
continue
# Filter 3: Check correctness if we have answers
if correct_answers and problem in correct_answers:
correct = str(correct_answers[problem])
if correct not in text:
continue
filtered.append(trace)
return filtered
# Example filtering
correct_answers = {
"What is 15 + 28?": "43",
"If a train travels 60 miles in 2 hours, what is its speed?": "30",
"A store has 50 items. They sell 20% of them. How many are left?": "40",
}
filtered = filter_traces(teacher_traces, correct_answers)
print(f"Filtered: {len(filtered)}/{len(teacher_traces)} traces kept")Filtered: 5/6 traces kept
Results from DeepSeek¶
From the DeepSeek-R1 paper:
| Model | Method | AIME 2024 | MATH |
|---|---|---|---|
| Qwen2.5-7B | Base | 3.3% | 75.5% |
| Qwen2.5-7B | + RL alone | 10.0% | 79.3% |
| Qwen2.5-7B | + R1 distillation | 26.7% | 83.9% |
| Qwen2.5-32B | + R1 distillation | 43.3% | 90.2% |
Key insight: A 7B model with distillation dramatically outperforms a 7B model trained with RL alone. The reasoning patterns from the larger model transfer!
What We’ve Learned¶
Reasoning distillation transfers thinking patterns from large to small models:
Generate high-quality reasoning traces from teacher
Filter for correctness and clarity
Train student to reproduce the traces (SFT or KD)
The key insight:
Small models can’t discover complex reasoning on their own, but they can learn to imitate it.
Two loss functions:
Trace SFT: $\mathcal{L}_{\text{SFT}} = -\sum_{t} \log p_{\theta}(y_t \mid y_{<t}, x)$, plain cross-entropy on teacher-generated reasoning traces.
KD: $\mathcal{L}_{\text{KD}} = \alpha \, T^{2} \, \mathrm{KL}\!\left(p_{\text{teacher}}^{(T)} \,\big\|\, p_{\text{student}}^{(T)}\right) + (1 - \alpha)\, \mathcal{L}_{\text{CE}}$, matching the teacher's temperature-softened token distributions plus the hard targets.
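Putting it together, a minimal end-to-end pass with the helpers defined in this section (generate, filter, distill) might look like the sketch below; the problem list and answer key are placeholders.
# End-to-end sketch: generate -> filter -> distill, using the helpers above.
problems = [
    "What is 15 + 28?",
    "A store has 50 items. They sell 20% of them. How many are left?",
]
answer_key = {problems[0]: "43", problems[1]: "40"}

# 1. Collect candidate reasoning traces from the (frozen) teacher.
raw_traces = generate_teacher_traces(teacher, tokenizer, problems, n_per_problem=4)

# 2. Keep only traces that are step-by-step, reasonably sized, and correct.
good_traces = filter_traces(raw_traces, correct_answers=answer_key)

# 3. Distill: train the student to match the teacher on the surviving traces.
optimizer = torch.optim.AdamW(student.parameters(), lr=1e-5)
for epoch in range(3):
    avg_loss = train_distillation_epoch(teacher, student, tokenizer,
                                        good_traces, optimizer)
    print(f"Epoch {epoch + 1}: loss = {avg_loss:.2f}")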
This completes our journey through reasoning techniques!
Summary of the Section¶
We covered:
Chain-of-Thought — Think step by step
Self-Consistency — Sample many, vote
Tree of Thoughts — Explore and backtrack
Process Reward Models — Score each step
Best-of-N — Generate and verify
MCTS — Smart search
Budget Forcing — Control thinking length
GRPO — RL without a critic
Distillation — Transfer to smaller models
These techniques, combined, power the reasoning capabilities of models like o1 and DeepSeek-R1.
