
Reasoning Distillation

We’ve built powerful reasoning systems: chain-of-thought, process reward models, MCTS, GRPO. But these techniques work best with large models (70B+). What about reasoning on a phone? In a browser? On a tiny GPU?

Distillation transfers the reasoning patterns from a large teacher to a small student.

The DeepSeek Discovery

From the DeepSeek-R1 paper:

The reasoning patterns of larger models can be distilled into smaller models, resulting in better performance compared to the reasoning patterns discovered through RL on small models.

In other words:

  • Train a 70B model to reason with RL → good reasoning emerges

  • Distill to a 7B model → 7B inherits the reasoning patterns

  • Result: 7B with distillation > 7B with direct RL

The small model can’t discover complex reasoning patterns on its own, but it can learn to imitate them.

Types of Distillation

1. Standard Knowledge Distillation

Train student to match teacher’s output distributions.

$$\mathcal{L}_{\text{KD}} = \text{KL}\left(P_{\text{teacher}} \,\|\, P_{\text{student}}\right)$$

2. Reasoning Trace Distillation

Train student on teacher’s step-by-step solutions.

$$\mathcal{L}_{\text{trace}} = -\log P_{\text{student}}(\text{reasoning trace})$$

3. Behavioral Cloning

Just train student to produce the same final outputs.

For reasoning, trace distillation works best. The student learns not just what answer to give, but how to think.
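
Before loading real models, here is a tiny toy illustration of the KL term from type 1 (the logits below are made up purely for illustration). Note PyTorch's convention: F.kl_div takes the student's log-probabilities as the input and the teacher's probabilities as the target, which corresponds to KL(teacher || student).

import torch
import torch.nn.functional as F

# Hypothetical logits over a 3-token vocabulary
teacher_logits = torch.tensor([[2.0, 1.0, 0.1]])
student_logits = torch.tensor([[1.5, 1.5, 0.2]])

teacher_probs = F.softmax(teacher_logits, dim=-1)           # target: probabilities
student_log_probs = F.log_softmax(student_logits, dim=-1)   # input: log-probabilities

toy_kl = F.kl_div(student_log_probs, teacher_probs, reduction="batchmean")
print(f"KL(teacher || student) = {toy_kl.item():.4f}")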

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List
import numpy as np

# Load teacher and student models
# Teacher: larger model, Student: smaller model
print("Loading models...")

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
teacher = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct", dtype="auto")
student = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", dtype="auto")

device = "cuda" if torch.cuda.is_available() else "cpu"
teacher = teacher.to(device)
student = student.to(device)

teacher.eval()  # Teacher is frozen
# Student will be trained

print(f"Loaded on {device}")
print(f"Teacher: Qwen2.5-1.5B-Instruct")
print(f"Student: Qwen2.5-0.5B-Instruct")
Loading models...
Loaded on cuda
Teacher: Qwen2.5-1.5B-Instruct
Student: Qwen2.5-0.5B-Instruct

Generating Teacher Reasoning Traces

First, we need to collect high-quality reasoning traces from the teacher model.

What’s a Trace?

A reasoning trace is the complete step-by-step solution a model generates when solving a problem. Not just the final answer—the whole internal monologue.

For example, when asked “What is 15 + 28?”, a trace might look like:

Step 1: I'll break both numbers into parts: 15 = 10 + 5 and 28 = 20 + 8
Step 2: Add the tens: 10 + 20 = 30
Step 3: Add the ones: 5 + 8 = 13
Step 4: Combine: 30 + 13 = 43

The trace is everything from “Step 1” to “43”. It’s the how, not just the what.

In chain-of-thought prompting, we showed models a few example traces and asked them to produce their own. In distillation, we’re taking thousands of high-quality traces from a smart teacher model and training a smaller student to reproduce that same reasoning style.

Think of it like showing a student worked examples in a textbook. The trace is the worked example—every intermediate calculation, every logical step laid out explicitly.

def generate_teacher_traces(teacher, tokenizer, problems: List[str],
                            n_per_problem: int = 3,
                            max_tokens: int = 150) -> List[dict]:
    """
    Generate reasoning traces from the teacher model.
    
    We'll generate multiple traces per problem and filter for correctness.
    """
    traces = []
    
    for problem in problems:
        prompt = f"Problem: {problem}\n\nSolution: Let me solve this step by step.\n"
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        
        for _ in range(n_per_problem):
            with torch.no_grad():
                outputs = teacher.generate(
                    **inputs,
                    max_new_tokens=max_tokens,
                    temperature=0.7,
                    do_sample=True,
                    pad_token_id=tokenizer.eos_token_id,
                )
            
            full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
            trace = full_text[len(prompt):]
            
            traces.append({
                "problem": problem,
                "prompt": prompt,
                "trace": trace,
                "full_text": prompt + trace
            })
    
    return traces


# Generate some traces
problems = [
    "What is 15 + 28?",
    "If a train travels 60 miles in 2 hours, what is its speed?",
    "A store has 50 items. They sell 20% of them. How many are left?",
]

print("Generating teacher traces...")
teacher_traces = generate_teacher_traces(teacher, tokenizer, problems, n_per_problem=2)

print(f"\nGenerated {len(teacher_traces)} traces")
print("\nExample trace:")
print("="*60)
print(teacher_traces[0]["full_text"][:300] + "...")
Generating teacher traces...

Generated 6 traces

Example trace:
============================================================
Problem: What is 15 + 28?

Solution: Let me solve this step by step.
Step 1: Start with the first number, which is 15. 
Step 2: Add 20 to it because 20 + 5 = 25
Step 3: Now add another 8 because 25 + 8 = 33

So, 15 + 28 equals 43.

Final answer: 15 + 28 = 43

Let's verify our solution:

15 + 28 = (1...

Trace Distillation Loss

The simplest form: train the student to produce the same reasoning traces as the teacher.

This is just supervised fine-tuning on teacher-generated data!

def compute_trace_distillation_loss(student, tokenizer, 
                                     trace: dict) -> torch.Tensor:
    """
    Compute loss for matching a teacher trace.
    
    This is just cross-entropy on the reasoning steps.
    """
    full_text = trace["full_text"]
    prompt_len = len(tokenizer(trace["prompt"])["input_ids"])  # token length of the prompt (for masking; unused here)
    
    inputs = tokenizer(full_text, return_tensors="pt").to(device)
    
    # Forward pass
    outputs = student(**inputs, labels=inputs["input_ids"])
    
    # We only care about loss on the reasoning trace, not the prompt
    # In practice, we'd mask the prompt tokens
    # For simplicity, we'll use the full loss here
    
    return outputs.loss


# Test
loss = compute_trace_distillation_loss(student, tokenizer, teacher_traces[0])
print(f"Distillation loss: {loss.item():.4f}")
Distillation loss: 0.9808

This loss (~0.98) is actually pretty good for a first attempt! Remember, this is cross-entropy: it measures how confident the student is about each next token. A loss of 0 would mean perfect certainty (impossible), while a loss around 1 means the student is reasonably confident but still learning.

For context: uniform random guessing over Qwen2.5's roughly 150,000-token vocabulary would give a loss of about ln(150,000) ≈ 12. So 0.98 means the student already has a decent prior; it's not flailing randomly. After training on many examples, we'd expect this to drop closer to 0.5 or lower.
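
As the comment in the code above notes, a real setup would mask the prompt tokens so the loss covers only the reasoning trace. Here's a minimal sketch of that masking (the helper name compute_masked_trace_loss is ours; it relies on Hugging Face's convention that labels set to -100 are ignored by the loss):

def compute_masked_trace_loss(student, tokenizer, trace: dict) -> torch.Tensor:
    """Cross-entropy on the trace only; prompt tokens are excluded from the loss."""
    inputs = tokenizer(trace["full_text"], return_tensors="pt").to(device)
    prompt_len = len(tokenizer(trace["prompt"])["input_ids"])

    labels = inputs["input_ids"].clone()
    labels[:, :prompt_len] = -100  # positions labeled -100 are ignored by the loss

    outputs = student(**inputs, labels=labels)
    return outputs.loss


masked_loss = compute_masked_trace_loss(student, tokenizer, teacher_traces[0])
print(f"Masked distillation loss: {masked_loss.item():.4f}")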

Token-Level Knowledge Distillation

A more sophisticated approach: match the teacher’s probability distribution at each token position.

$$\mathcal{L}_{\text{KD}} = \sum_t \text{KL}\left( P_T^{(\tau)}(y_t \mid y_{<t}) \,\big\|\, P_S^{(\tau)}(y_t \mid y_{<t}) \right)$$

where $P^{(\tau)}$ denotes the softmax of the logits divided by the temperature $\tau$, which softens both distributions.
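
To see what that softening looks like, here's a quick toy example (hypothetical logits for a four-token vocabulary, reusing the torch imports from above):

# Higher temperature flattens the softmax, exposing the teacher's "near misses"
logits = torch.tensor([4.0, 2.0, 1.0, 0.5])

for tau in [1.0, 2.0, 5.0]:
    probs = F.softmax(logits / tau, dim=-1)
    print(f"tau={tau}: {[round(p, 3) for p in probs.tolist()]}")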

def knowledge_distillation_loss(teacher, student, tokenizer,
                                 text: str, temperature: float = 2.0,
                                 alpha: float = 0.5) -> tuple:
    """
    Token-level knowledge distillation.
    
    Combines:
    1. KL divergence from teacher distributions
    2. Hard target cross-entropy
    
    Args:
        teacher: Teacher model (frozen)
        student: Student model (training)
        tokenizer: Tokenizer
        text: Text to distill on
        temperature: Softening temperature
        alpha: Weight on distillation vs. hard targets
    
    Returns:
        (total_loss, kl_loss, hard_loss)
    """
    inputs = tokenizer(text, return_tensors="pt").to(device)
    
    # Get teacher logits
    with torch.no_grad():
        teacher_outputs = teacher(**inputs)
        teacher_logits = teacher_outputs.logits
    
    # Get student logits
    student_outputs = student(**inputs)
    student_logits = student_outputs.logits
    
    # Soft targets (temperature-scaled)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    
    # KL divergence loss
    kl_loss = F.kl_div(
        student_log_probs[:, :-1, :],  # drop the last position (no next-token target)
        teacher_probs[:, :-1, :],
        reduction='batchmean'
    )
    
    # Hard target loss (standard cross-entropy)
    hard_loss = F.cross_entropy(
        student_logits[:, :-1, :].reshape(-1, student_logits.size(-1)),
        inputs["input_ids"][:, 1:].reshape(-1)
    )
    
    # Combined loss
    # Scale KL by T^2 (standard practice)
    total_loss = alpha * (temperature ** 2) * kl_loss + (1 - alpha) * hard_loss
    
    return total_loss, kl_loss, hard_loss


# Test
test_text = teacher_traces[0]["full_text"]
total, kl, hard = knowledge_distillation_loss(
    teacher, student, tokenizer, test_text
)

print(f"Knowledge distillation losses:")
print(f"  KL loss: {kl.item():.4f}")
print(f"  Hard target loss: {hard.item():.4f}")
print(f"  Total: {total.item():.4f}")
Knowledge distillation losses:
  KL loss: 57.7500
  Hard target loss: 0.9805
  Total: 116.0000

Wait, why is the total loss (116.0) so much bigger than either component?

Because of the $T^2$ scaling factor! When we use temperature $T=2$, we multiply the KL loss by $2^2 = 4$. This is standard practice in distillation; it balances the contribution of the soft targets against the hard targets.

Breaking it down:

  • KL loss (57.75): How different are the student's probabilities from the teacher's? With reduction='batchmean', the divergence is summed over every position in the sequence, so this number is naturally much larger than a per-token cross-entropy.

  • Hard target loss (0.98): Standard cross-entropy against the actual next token. Same scale as the trace distillation loss above.

  • Total (116.0): $\alpha \cdot T^2 \cdot \text{KL} + (1-\alpha) \cdot \text{hard} = 0.5 \cdot 4 \cdot 57.75 + 0.5 \cdot 0.98 \approx 116.0$

The total will decrease as training progresses and the student learns to match the teacher’s distribution better.
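
We can check that arithmetic directly with the tensors from the test above (alpha and temperature here are the function defaults, 0.5 and 2.0):

# Recombine the components to confirm they reproduce the reported total
alpha, temperature = 0.5, 2.0
recombined = alpha * (temperature ** 2) * kl + (1 - alpha) * hard
print(f"Recombined loss: {recombined.item():.4f} (matches total: {total.item():.4f})")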

Training Loop for Distillation

def train_distillation_epoch(teacher, student, tokenizer,
                              traces: List[dict], optimizer,
                              temperature: float = 2.0,
                              alpha: float = 0.5) -> float:
    """
    Train student for one epoch on teacher traces.
    """
    student.train()
    total_loss = 0.0
    
    for trace in traces:
        optimizer.zero_grad()
        
        loss, _, _ = knowledge_distillation_loss(
            teacher, student, tokenizer,
            trace["full_text"],
            temperature=temperature,
            alpha=alpha
        )
        
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    return total_loss / len(traces)


# Train for a few epochs
optimizer = torch.optim.AdamW(student.parameters(), lr=1e-5)

print("Training student with distillation...")
print("="*50)

for epoch in range(3):
    loss = train_distillation_epoch(
        teacher, student, tokenizer,
        teacher_traces, optimizer
    )
    print(f"Epoch {epoch+1}: Loss = {loss:.4f}")

print("\nDistillation complete!")
Training student with distillation...
==================================================
Epoch 1: Loss = 106.2500
Epoch 2: Loss = 83.0833
Epoch 3: Loss = 74.0417

Distillation complete!

These losses look huge (106 → 83 → 74) but they're actually fine! Remember, we're averaging the distillation loss across all traces, and that loss includes the $T^2$ scaling we saw above.

What matters is the trend: loss is decreasing consistently. The student is learning to match the teacher’s reasoning patterns better with each epoch. In a real training run with thousands of traces, you’d see this drop much further (into the 20s or lower).

Also note: we're using the knowledge distillation loss here (with the KL component), not the simpler trace distillation loss, so the scale is naturally larger. If we were just doing trace SFT, we'd see losses closer to the 0.98 we got in the first test.
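
If you want a per-token view of the KL component, you can divide the summed KL by the number of predicted positions. A small sketch reusing the kl value and test_text from the earlier test cell (those were computed before training, so this is only illustrative):

# Average KL per position (kl was summed over all positions by 'batchmean' with batch size 1)
inputs = tokenizer(test_text, return_tensors="pt").to(device)
n_positions = inputs["input_ids"].size(1) - 1  # we dropped the last position in the KL slice
print(f"Per-position KL: {kl.item() / n_positions:.4f}")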

Quality Filtering

Not all teacher traces are worth imitating. We should filter for:

  1. Correct answers — Don’t teach wrong reasoning

  2. Clear steps — Mumbled reasoning is hard to learn from

  3. Diverse approaches — Multiple ways to solve problems

import re

def filter_traces(traces: List[dict], correct_answers: dict = None) -> List[dict]:
    """
    Filter teacher traces for quality.
    
    Args:
        traces: List of trace dicts
        correct_answers: Dict mapping problems to correct answers
    
    Returns:
        Filtered list of high-quality traces
    """
    filtered = []
    
    for trace in traces:
        text = trace["trace"]
        problem = trace["problem"]
        
        # Filter 1: Must have step-by-step structure
        has_steps = any(marker in text.lower() 
                       for marker in ['step', 'first', 'then', 'next', 'finally'])
        if not has_steps:
            continue
        
        # Filter 2: Must have reasonable length
        n_words = len(text.split())
        if n_words < 20 or n_words > 300:
            continue
        
        # Filter 3: Check correctness if we have answers
        if correct_answers and problem in correct_answers:
            correct = str(correct_answers[problem])
            if correct not in text:
                continue
        
        filtered.append(trace)
    
    return filtered


# Example filtering
correct_answers = {
    "What is 15 + 28?": "43",
    "If a train travels 60 miles in 2 hours, what is its speed?": "30",
    "A store has 50 items. They sell 20% of them. How many are left?": "40",
}

filtered = filter_traces(teacher_traces, correct_answers)
print(f"Filtered: {len(filtered)}/{len(teacher_traces)} traces kept")
Filtered: 5/6 traces kept
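
One caveat: the substring check in Filter 3 can pass a trace whose reasoning is wrong but which mentions the correct number somewhere mid-calculation. A slightly stricter heuristic (the extract_final_answer helper and its regex are ours) pulls the last number that follows an '=' sign and compares that instead:

import re
from typing import Optional

def extract_final_answer(text: str) -> Optional[str]:
    """Heuristic: take the last number that appears after an '=' sign."""
    numbers = re.findall(r"=\s*(-?\d+(?:\.\d+)?)", text)
    return numbers[-1] if numbers else None


for trace in teacher_traces[:2]:
    expected = str(correct_answers[trace["problem"]])
    extracted = extract_final_answer(trace["trace"])
    print(f"expected={expected}  extracted={extracted}  match={extracted == expected}")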

Results from DeepSeek

From the DeepSeek-R1 paper:

| Model | Method | AIME 2024 | MATH |
|---|---|---|---|
| Qwen2.5-7B | Base | 3.3% | 75.5% |
| Qwen2.5-7B | + RL alone | 10.0% | 79.3% |
| Qwen2.5-7B | + R1 distillation | 26.7% | 83.9% |
| Qwen2.5-32B | + R1 distillation | 43.3% | 90.2% |

Key insight: A 7B model with distillation dramatically outperforms a 7B model trained with RL alone. The reasoning patterns from the larger model transfer!

What We’ve Learned

Reasoning distillation transfers thinking patterns from large to small models:

  1. Generate high-quality reasoning traces from teacher

  2. Filter for correctness and clarity

  3. Train student to reproduce the traces (SFT or KD)

The key insight:

Small models can’t discover complex reasoning on their own, but they can learn to imitate it.

Two loss functions:

  • Trace SFT: $\mathcal{L} = -\log P_S(\text{trace})$

  • KD: $\mathcal{L} = \alpha \cdot T^2 \cdot \text{KL}\big(P_T^{(\tau)} \,\|\, P_S^{(\tau)}\big) + (1-\alpha) \cdot \text{CE}$

This completes our journey through reasoning techniques!

Summary of the Section

We covered:

  1. Chain-of-Thought — Think step by step

  2. Self-Consistency — Sample many, vote

  3. Tree of Thoughts — Explore and backtrack

  4. Process Reward Models — Score each step

  5. Best-of-N — Generate and verify

  6. MCTS — Smart search

  7. Budget Forcing — Control thinking length

  8. GRPO — RL without a critic

  9. Distillation — Transfer to smaller models

These techniques, combined, power the reasoning capabilities of models like o1 and DeepSeek-R1.