Alright. Deep breath.
We’ve learned about tokenization, formatting, and all the prep work. Now it’s time to actually train the thing.
We’re going to take a pre-trained model and teach it new tricks through supervised fine-tuning.
The Training Journey (A Roadmap)
Think of training as a road trip. You need:
A vehicle (the pre-trained model and tokenizer)
Fuel (the training data, properly formatted)
A route (the data loader that feeds examples in batches)
Navigation (the optimizer and learning rate scheduler)
The actual driving (the training loop where learning happens)
Rest stops (evaluation checkpoints)
Your destination (the fine-tuned model, saved and ready to use)
Each step matters. Skip one and you’re stranded on the side of the road with a confused transformer.
Let’s build this piece by piece.
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, get_linear_schedule_with_warmup
from datasets import load_dataset
from tqdm.auto import tqdm
import numpy as np
# Setup device - GPU if available, CPU otherwise
# (Training on CPU is like walking to San Francisco. Technically possible, but...)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if device.type == "cuda":
print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")Using device: cuda
GPU: Radeon RX 7900 XTX
Memory: 25.75 GB
The Dataset Class: Preparing Training Examples
Remember our Alpaca template from before? Now we need to wrap it in a PyTorch Dataset class.
Why? Because the training loop doesn’t want to deal with raw data. It wants a nice, standardized interface where it can say “give me example #42” and get back properly formatted tensors.
Think of this as a restaurant kitchen. The chef (training loop) doesn’t want to deal with whole chickens and raw vegetables. They want prepped ingredients, measured and ready to cook.
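If PyTorch's Dataset is new to you, here's a toy sketch first (made-up numbers, nothing to do with our real data): the class just has to answer "how many examples?" and "give me example #idx", and the DataLoader handles the batching.
from torch.utils.data import Dataset, DataLoader
import torch

class ToyDataset(Dataset):
    """A made-up dataset: just wraps a list of numbers."""
    def __init__(self, values):
        self.values = values

    def __len__(self):
        return len(self.values)

    def __getitem__(self, idx):
        # Return tensors, just like our real dataset will
        return {"x": torch.tensor(float(self.values[idx]))}

toy_loader = DataLoader(ToyDataset([1, 2, 3, 4, 5, 6]), batch_size=2)
for batch in toy_loader:
    print(batch["x"])  # the DataLoader stacked two examples into one tensor
Our real dataset below does exactly this, except each "ingredient" is a dict of token IDs, attention mask, and labels.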
ALPACA_TEMPLATE = """Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
{instruction}
### Response:
{response}"""
class SFTDataset(Dataset):
"""Dataset for supervised fine-tuning.
This class does three critical things:
1. Formats examples with the Alpaca template
2. Tokenizes text into numbers the model understands
3. Creates labels that mask the prompt (we only train on responses!)
"""
def __init__(self, data, tokenizer, max_length=512):
self.data = data
self.tokenizer = tokenizer
self.max_length = max_length
def __len__(self):
"""How many examples do we have?"""
return len(self.data)
def __getitem__(self, idx):
"""Get one training example, fully prepped and ready."""
item = self.data[idx]
# Step 1: Format with our template
formatted = ALPACA_TEMPLATE.format(
instruction=item['instruction'],
response=item['output']
)
# Step 2: Figure out where the response starts
# We need this because we DON'T want to train on the instruction part
# Only the response should contribute to the loss
prompt = ALPACA_TEMPLATE.format(
instruction=item['instruction'],
response='' # Empty response to find where it would start
)
# Step 3: Tokenize the full text
full_tokens = self.tokenizer(
formatted,
max_length=self.max_length,
truncation=True, # Cut off if too long
padding='max_length', # Pad to consistent length
return_tensors='pt'
)
# Step 4: Tokenize just the prompt to find response boundary
prompt_tokens = self.tokenizer(
prompt,
max_length=self.max_length,
truncation=True,
return_tensors='pt'
)
response_start = prompt_tokens['input_ids'].shape[1]
# Step 5: Create labels with masking
# Here's the key insight: we copy the input_ids but mask the prompt
# The -100 value tells PyTorch "don't calculate loss for these tokens"
labels = full_tokens['input_ids'].clone().squeeze(0)
labels[:response_start] = -100 # Mask the instruction
# Also mask padding tokens (we don't want to train on padding!)
labels[labels == self.tokenizer.pad_token_id] = -100
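# (Heads up: we'll later reuse the EOS token as the pad token, and nothing here
# appends an EOS to the response, so the model never gets a "stop here" signal.
# That's one reason quick demos like this tend to ramble when generating.)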
# Return everything the training loop needs
return {
'input_ids': full_tokens['input_ids'].squeeze(0),
'attention_mask': full_tokens['attention_mask'].squeeze(0),
'labels': labels
}
# Let's test it with a sample
print("Dataset class defined successfully!")
print("\nKey insight: We mask the instruction with -100 so the model only learns")
print("to predict the response. This is supervised fine-tuning in action.")Dataset class defined successfully!
Key insight: We mask the instruction with -100 so the model only learns
to predict the response. This is supervised fine-tuning in action.
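By the way, -100 isn't an arbitrary choice: PyTorch's cross-entropy loss ignores that label index by default (ignore_index=-100). Here's a tiny sketch with made-up logits, just to see the mechanic in isolation:
import torch
import torch.nn.functional as F

# Made-up logits for 4 token positions over a tiny 5-token vocabulary
logits = torch.randn(4, 5)
labels_full = torch.tensor([2, 0, 3, 1])          # train on all four positions
labels_masked = torch.tensor([2, -100, -100, 1])  # positions 1 and 2 ignored

# cross_entropy's default ignore_index is -100, so the masked positions
# simply don't contribute to the loss - that's all our label masking relies on
print(F.cross_entropy(logits, labels_full).item())
print(F.cross_entropy(logits, labels_masked).item())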
Training Configuration: All the Knobs and Dials
Before we train, we need to make some decisions. How fast should we learn? How many examples at once? How many times through the data?
These are called hyperparameters (fancy name for “settings we tune by hand”). They can make or break your training run.
Too fast? The model goes haywire. Too slow? You’re waiting until the heat death of the universe. Too much data at once? Out of memory. Too little? Noisy, unstable training.
The settings below are reasonable defaults for our small demo. In the real world, you’d spend days (weeks?) tuning these.
from dataclasses import dataclass
@dataclass
class SFTConfig:
"""Configuration for SFT training.
Let's break down what each setting does:
"""
# Which model to start from
model_name: str = "gpt2"
# Maximum sequence length (longer = more memory)
max_length: int = 512
# Batch size: how many examples to process at once
# Bigger = faster but needs more memory
batch_size: int = 4
# Learning rate: how big of steps to take when updating weights
# This is the single most important hyperparameter
learning_rate: float = 2e-4 # 0.0002 - on the large side for full fine-tuning, but fine for a short demo
# Epochs: how many times to loop through the entire dataset
# One epoch = seeing every training example once
num_epochs: int = 3
# Warmup steps: gradually increase learning rate at the start
# Prevents the model from making crazy updates early on
warmup_steps: int = 100
# Gradient accumulation: simulate bigger batches without more memory
# Process this many batches before updating weights
# Effective batch size = batch_size * gradient_accumulation_steps
gradient_accumulation_steps: int = 4
# Gradient clipping: prevent exploding gradients
# If gradients get bigger than this, scale them down
max_grad_norm: float = 1.0
# How often to log progress
logging_steps: int = 10
# How often to check validation loss
eval_steps: int = 100
# How often to save checkpoints
save_steps: int = 500
# Where to save the model
output_dir: str = "./sft_output"
config = SFTConfig()
print("Training configuration:")
print("=" * 60)
for k, v in vars(config).items():
print(f" {k:.<35} {v}")
print("=" * 60)
print(f"\nEffective batch size: {config.batch_size * config.gradient_accumulation_steps}")
print("(That's how many examples we process before each weight update)")Training configuration:
============================================================
model_name......................... gpt2
max_length......................... 512
batch_size......................... 4
learning_rate...................... 0.0002
num_epochs......................... 3
warmup_steps....................... 100
gradient_accumulation_steps........ 4
max_grad_norm...................... 1.0
logging_steps...................... 10
eval_steps......................... 100
save_steps......................... 500
output_dir......................... ./sft_output
============================================================
Effective batch size: 16
(That's how many examples we process before each weight update)
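If "simulate a bigger batch" sounds like hand-waving, here's a tiny numeric check using a made-up one-weight model (not our real setup): accumulating gradients over four micro-batches, each loss scaled by 1/4, gives the same gradient as one big batch of all 16 examples.
import torch

torch.manual_seed(0)
x, y = torch.randn(16), torch.randn(16)
w_big = torch.tensor(1.0, requires_grad=True)
w_acc = torch.tensor(1.0, requires_grad=True)

# One big batch of 16
((w_big * x - y) ** 2).mean().backward()

# Four micro-batches of 4, gradients accumulating across backward() calls
for i in range(4):
    xb, yb = x[i*4:(i+1)*4], y[i*4:(i+1)*4]
    micro_loss = ((w_acc * xb - yb) ** 2).mean() / 4  # scale by accumulation steps
    micro_loss.backward()

print(f"big-batch gradient:   {w_big.grad.item():.6f}")
print(f"accumulated gradient: {w_acc.grad.item():.6f}")
The two gradients match (up to floating-point error), which is exactly why the training loop below divides the loss by gradient_accumulation_steps before calling backward().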
The Training Loop
Okay, here’s the heart of it all. The training loop.
This is where we actually teach the model. It’s a dance with three steps, repeated thousands of times:
Forward pass: Show the model an example, see what it predicts
Backward pass: Calculate how wrong it was, compute gradients
Update: Adjust the weights to be a little less wrong next time
Rinse and repeat until the model gets good (or you run out of patience/GPU credits).
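Before the full version, here's the dance stripped to the bone on a made-up toy problem (fit y = 3x with a single weight) instead of our actual model:
import torch

# A made-up toy problem: learn to predict y = 3x with a single weight.
# Same three steps as the real loop below, minus all the bookkeeping.
w = torch.tensor(0.0, requires_grad=True)
optimizer = torch.optim.AdamW([w], lr=0.02)

for step in range(500):
    x = torch.randn(8)                  # a "batch" of inputs
    y = 3 * x                           # the targets we want to hit
    loss = ((w * x - y) ** 2).mean()    # forward pass: how wrong are we?
    loss.backward()                     # backward pass: compute gradients
    optimizer.step()                    # update: nudge w to be less wrong
    optimizer.zero_grad()               # clear gradients for the next step

print(f"Learned w = {w.item():.2f} (it should land close to the true value, 3)")
That's it. Everything else in the real loop is bookkeeping: batching, masking, accumulation, scheduling, logging, saving.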
The code below looks long, but it’s just those three steps wrapped in careful bookkeeping. Let me walk you through it.
def train_sft(model, tokenizer, train_dataset, eval_dataset, config):
"""Complete SFT training loop.
This function orchestrates the entire training process.
It's like conducting an orchestra - lots of moving parts, all needing
to work together in harmony.
"""
# Create data loaders
# These batch up our data and shuffle it each epoch
# Think of this as the assembly line that feeds examples to the model
train_loader = DataLoader(
train_dataset,
batch_size=config.batch_size,
shuffle=True, # Randomize order each epoch - helps training
num_workers=0 # Data-loading worker processes (0 = load in the main process)
)
eval_loader = DataLoader(
eval_dataset,
batch_size=config.batch_size,
shuffle=False # No need to shuffle validation data
)
# Setup optimizer
# The optimizer is what actually updates the model weights
# AdamW is the gold standard for transformer training
optimizer = torch.optim.AdamW(
model.parameters(),
lr=config.learning_rate,
weight_decay=0.01 # Regularization to prevent overfitting
)
# Learning rate scheduler
# We don't use a fixed learning rate - we adjust it over time
# Start with warmup (gradual increase), then linear decay
total_steps = len(train_loader) * config.num_epochs // config.gradient_accumulation_steps
scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=config.warmup_steps,
num_training_steps=total_steps
)
# Training state
model.train() # Put model in training mode (enables dropout, etc.)
global_step = 0 # Total number of weight updates
best_eval_loss = float('inf') # Track best model
# The main training loop - iterate over epochs
# An epoch = one complete pass through the training data
for epoch in range(config.num_epochs):
epoch_loss = 0
# Progress bar for visual feedback (because watching loss decrease is satisfying)
progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.num_epochs}")
# Iterate over batches within this epoch
for step, batch in enumerate(progress_bar):
# Move batch to GPU (if available)
batch = {k: v.to(device) for k, v in batch.items()}
# ===== FORWARD PASS =====
# Feed the input through the model, get predictions and loss
outputs = model(
input_ids=batch['input_ids'],
attention_mask=batch['attention_mask'],
labels=batch['labels'] # The model computes loss for us!
)
# Scale loss by accumulation steps
# Why? Because we're going to accumulate gradients across multiple batches
loss = outputs.loss / config.gradient_accumulation_steps
# ===== BACKWARD PASS =====
# Compute gradients - this is where the learning happens
# PyTorch automagically calculates how to adjust each weight
loss.backward()
epoch_loss += loss.item() * config.gradient_accumulation_steps
# ===== UPDATE WEIGHTS =====
# Only update every gradient_accumulation_steps
# This simulates a larger batch size without using more memory
if (step + 1) % config.gradient_accumulation_steps == 0:
# Clip gradients to prevent explosions
# Sometimes gradients get REALLY big - this caps them
torch.nn.utils.clip_grad_norm_(
model.parameters(),
config.max_grad_norm
)
# Take the optimization step
optimizer.step()
scheduler.step() # Update learning rate
optimizer.zero_grad() # Clear gradients for next batch
global_step += 1
# Log progress
if global_step % config.logging_steps == 0:
avg_loss = epoch_loss / (step + 1)
progress_bar.set_postfix({
'loss': f'{avg_loss:.4f}',
'lr': f'{scheduler.get_last_lr()[0]:.2e}'
})
# End of epoch - check how we're doing on validation data
eval_loss = evaluate(model, eval_loader, device)
print(f"\nEpoch {epoch+1} - Train Loss: {epoch_loss/len(train_loader):.4f}, Eval Loss: {eval_loss:.4f}")
# Save best model
# We keep the model with the lowest validation loss
if eval_loss < best_eval_loss:
best_eval_loss = eval_loss
model.save_pretrained(f"{config.output_dir}/best")
tokenizer.save_pretrained(f"{config.output_dir}/best")
print(f"Saved best model with eval loss: {eval_loss:.4f}")
return model
def evaluate(model, eval_loader, device):
"""Evaluate model on validation set.
This tells us how well the model generalizes to new data.
If training loss goes down but eval loss goes up... that's overfitting.
Bad news bears.
"""
model.eval() # Put model in evaluation mode (disables dropout, etc.)
total_loss = 0
# No gradients needed for evaluation - saves memory and time
with torch.no_grad():
for batch in eval_loader:
batch = {k: v.to(device) for k, v in batch.items()}
outputs = model(
input_ids=batch['input_ids'],
attention_mask=batch['attention_mask'],
labels=batch['labels']
)
total_loss += outputs.loss.item()
model.train() # Put model back in training mode
return total_loss / len(eval_loader)
print("Training loop defined!")
print("\nKey concepts:")
print(" • Forward pass: Model makes predictions")
print(" • Backward pass: Calculate gradients (how to improve)")
print(" • Optimizer step: Actually update the weights")
print(" • Gradient accumulation: Simulate bigger batches")
print(" • Validation: Check if we're learning or just memorizing")Training loop defined!
Key concepts:
• Forward pass: Model makes predictions
• Backward pass: Calculate gradients (how to improve)
• Optimizer step: Actually update the weights
• Gradient accumulation: Simulate bigger batches
• Validation: Check if we're learning or just memorizing
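Curious what "warmup then linear decay" actually looks like? A throwaway sketch with a dummy parameter (it never touches our real model) that just prints the learning rate at a few points along a 1000-step schedule:
# torch and get_linear_schedule_with_warmup were already imported at the top
dummy_param = torch.nn.Parameter(torch.zeros(1))
dummy_opt = torch.optim.AdamW([dummy_param], lr=2e-4)
dummy_sched = get_linear_schedule_with_warmup(
    dummy_opt, num_warmup_steps=100, num_training_steps=1000
)

for step in range(1000):
    dummy_opt.step()
    dummy_sched.step()
    if step in (0, 50, 99, 500, 999):
        print(f"step {step:3d}: lr = {dummy_sched.get_last_lr()[0]:.2e}")
The rate climbs linearly to the peak over the first 100 steps, then slides back down to zero by the end of training.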
Setting Up for Training
Alright, let’s load our model and data. This is the pre-flight checklist before takeoff.
We’ll use GPT-2 as our base model (it’s small enough to train quickly but still powerful), and the Alpaca dataset for training examples.
For this demo, we’re using a tiny subset of the data. In the real world, you’d use the full dataset and let it run for hours (days?). But we’re trying to learn here, not max out your electricity bill.
# Load model and tokenizer
model_name = "gpt2"
print(f"Loading {model_name}...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# GPT-2 doesn't have a pad token by default, so we'll use the EOS token
# (This is a common trick - pad tokens are just for batching anyway)
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.pad_token_id
# Move model to GPU
model.to(device)
print(f"✓ Model loaded and moved to {device}")
# Load dataset
print("\nLoading Alpaca dataset...")
raw_data = load_dataset("yahma/alpaca-cleaned", split="train")
# Use a small subset for this demo
# (Training on 1000 examples instead of 52,000 - your GPU will thank me)
raw_data = raw_data.select(range(1000))
print(f"✓ Loaded {len(raw_data)} examples")
# Split into train/eval (90/10 split)
train_size = int(0.9 * len(raw_data))
train_data = raw_data.select(range(train_size))
eval_data = raw_data.select(range(train_size, len(raw_data)))
# Create datasets using our SFTDataset class
# We use shorter sequences (256) to save memory
train_dataset = SFTDataset(train_data, tokenizer, max_length=256)
eval_dataset = SFTDataset(eval_data, tokenizer, max_length=256)
print(f"✓ Created datasets:")
print(f" - Training: {len(train_dataset)} examples")
print(f" - Validation: {len(eval_dataset)} examples")
# Let's peek at one example to make sure everything looks right
sample = train_dataset[0]
print(f"\nSample training example:")
print(f" - Input IDs shape: {sample['input_ids'].shape}")
print(f" - Labels shape: {sample['labels'].shape}")
print(f" - Number of masked tokens: {(sample['labels'] == -100).sum().item()}")
print(f" - Number of trainable tokens: {(sample['labels'] != -100).sum().item()}")Loading gpt2...
✓ Model loaded and moved to cuda
Loading Alpaca dataset...
✓ Loaded 1000 examples
✓ Created datasets:
- Training: 900 examples
- Validation: 100 examples
Sample training example:
- Input IDs shape: torch.Size([256])
- Labels shape: torch.Size([256])
- Number of masked tokens: 106
- Number of trainable tokens: 150
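One more sanity check before takeoff, reusing the sample and tokenizer from the cell above: if the masking is doing its job, decoding only the unmasked positions should give us just the response text, with no instruction in sight.
# Decode only the positions we actually train on
trainable_ids = sample['input_ids'][sample['labels'] != -100]
print(tokenizer.decode(trainable_ids))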
# Adjust config for a quick demo run
config.num_epochs = 1
config.max_length = 256 # Shorter sequences = faster training
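# Note: with 900 examples, batch size 4, and 4 accumulation steps, one epoch is
# only ~56 optimizer updates, so the 100-step warmup never finishes ramping.
# That's fine for a quick demo, but worth revisiting in a real run.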
print("Starting training...")
print("=" * 60)
# This will take a few minutes
# Watch the loss decrease - that's learning in action!
model = train_sft(model, tokenizer, train_dataset, eval_dataset, config)
print("=" * 60)
print("✓ Training complete!")
print("\nWhat just happened?")
print(" • The model saw 900 training examples")
print(" • It adjusted its weights to minimize prediction errors")
print(" • We validated on 100 held-out examples to check generalization")
print(" • The best model was saved to disk")
print("\nNext up: Let's see if it actually learned anything!")Starting training...
============================================================
`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.
Epoch 1 - Train Loss: 2.5208, Eval Loss: 2.4195
Saved best model with eval loss: 2.4195
============================================================
✓ Training complete!
What just happened?
• The model saw 900 training examples
• It adjusted its weights to minimize prediction errors
• We validated on 100 held-out examples to check generalization
• The best model was saved to disk
Next up: Let's see if it actually learned anything!
Testing the Fine-Tuned Model
The proof is in the pudding. Let’s see if our model can actually follow instructions now.
We’ll give it a prompt and watch it generate a response. If training worked, it should follow the Alpaca format and give helpful answers.
If it starts rambling about random nonsense... well, that’s why we have validation loss. (And why we save the best checkpoint, not the final one.)
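One practical note before we test: train_sft saved the best checkpoint to ./sft_output/best, while the model variable below still holds the final weights from training. With a single epoch those happen to be the same thing, but in a longer run you'd reload the best checkpoint first, roughly like this:
# Reload whichever checkpoint scored the lowest eval loss
best_dir = f"{config.output_dir}/best"
best_model = AutoModelForCausalLM.from_pretrained(best_dir).to(device)
best_tokenizer = AutoTokenizer.from_pretrained(best_dir)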
def generate_response(model, tokenizer, instruction, max_new_tokens=100):
"""Generate a response for an instruction.
This is the moment of truth - does the model actually follow instructions?
"""
# Format the prompt using our Alpaca template
prompt = ALPACA_TEMPLATE.format(instruction=instruction, response='')
# Tokenize and move to device
inputs = tokenizer(prompt, return_tensors='pt').to(device)
# Generate!
with torch.no_grad(): # No gradients needed for inference
outputs = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
do_sample=True, # Sampling for more natural text
temperature=0.7, # Higher = more random, lower = more conservative
top_p=0.9, # Nucleus sampling - keep top 90% probability mass
pad_token_id=tokenizer.pad_token_id
)
# Decode the generated tokens back to text
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Extract just the response part (after "### Response:")
response = response.split("### Response:")[-1].strip()
return response
# Test with a few different instructions
test_instructions = [
"Explain what machine learning is in simple terms.",
"Write a haiku about programming.",
"What are the three branches of the US government?"
]
print("Testing the fine-tuned model...")
print("=" * 60)
for instruction in test_instructions:
print(f"\nInstruction: {instruction}")
print("-" * 60)
response = generate_response(model, tokenizer, instruction)
print(f"Response: {response}")
print("=" * 60)
print("\nHow'd we do? The model should:")
print(" ✓ Follow the instruction format")
print(" ✓ Stay on topic")
print(" ✓ Generate coherent (if not always perfect) responses")
print("\nRemember: We only trained for one epoch on 900 examples.")
print("With more data and training time, it would get much better!")Testing the fine-tuned model...
============================================================
Instruction: Explain what machine learning is in simple terms.
------------------------------------------------------------
Response: Machine learning is a field in which the development of machine learning algorithms can be used to improve and improve their accuracy, efficiency, and performance. Machine learning algorithms are increasingly used by a wide range of industries, including consumer, office, and transportation. Machine learning is a popular tool that is used in the field of data analysis, machine learning, and machine learning.
Machine learning is the process of learning data and creating models, often by combining and analyzing data from a wide range of sources, often
============================================================
Instruction: Write a haiku about programming.
------------------------------------------------------------
Response: Programming is a branch of science and is often associated with the development of computer programs. It is often used to teach or study a specific skill or field, or to find a suitable subject for a particular job.
One of the most common uses of programming is to develop and improve computer programs. Programs are designed to provide a foundation of knowledge and provide tools to improve the computer's ability to perform tasks or perform tasks in a particular way. This means that programs can be implemented in a variety
============================================================
Instruction: What are the three branches of the US government?
------------------------------------------------------------
Response: The three branches of the US government are the United States, the United Kingdom, and Canada. The United States, in particular, is the United States, and the United Kingdom in general.
The United States is a country, not a country, that is governed by the Constitution and laws of the United States. The United Kingdom is a country that is governed by the laws of the United States.
The United Kingdom is a country that is governed by the laws of the United States,
============================================================
How'd we do? The model should:
✓ Follow the instruction format
✓ Stay on topic
✓ Generate coherent (if not always perfect) responses
Remember: We only trained for one epoch on 900 examples.
With more data and training time, it would get much better!
What We Learned
Congratulations! You just trained a transformer (well, fine-tuned an existing one, but that still counts).
Let’s recap the key concepts:
The Training Loop is three steps repeated thousands of times:
Forward pass: Show the model data, get predictions
Backward pass: Calculate gradients (how wrong were we?)
Optimizer step: Update weights to be less wrong
Epochs are complete passes through the training data. More epochs = more learning, but also more risk of overfitting.
Gradient Accumulation lets us simulate bigger batch sizes without running out of memory. We accumulate gradients over multiple small batches, then update.
Learning Rate Scheduling adjusts how aggressively we update weights. Start with warmup (ease into it), then gradually decrease (make smaller adjustments as we get closer to optimal).
Validation Loss tells us if we’re actually learning or just memorizing. If it starts going up while training loss goes down, that’s overfitting.
Label Masking (the -100 trick) means we only train on the parts we care about (the responses), not the prompts.
Next up: LoRA - a clever way to fine-tune with way fewer parameters. Because not everyone has a server farm in their basement.