Training a Transformer

We’ve learned about tokenization, formatting, and all the prep work. Now it’s time to actually train the thing.

We’re going to take a pre-trained model and teach it new tricks through supervised fine-tuning.

The Training Journey (A Roadmap)

Think of training as a road trip. You need:

  1. A vehicle (the pre-trained model and tokenizer)

  2. Fuel (the training data, properly formatted)

  3. A route (the data loader that feeds examples in batches)

  4. Navigation (the optimizer and learning rate scheduler)

  5. The actual driving (the training loop where learning happens)

  6. Rest stops (evaluation checkpoints)

  7. Your destination (the fine-tuned model, saved and ready to use)

Each step matters. Skip one and you’re stranded on the side of the road with a confused transformer and a mixed-up metaphor.

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, get_linear_schedule_with_warmup
from datasets import load_dataset
from tqdm.auto import tqdm
import numpy as np

# Setup device - GPU if available, CPU otherwise
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

if device.type == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
Using device: cuda
GPU: NVIDIA GeForce RTX 5090
Memory: 34.19 GB

The Dataset Class: Preparing Training Examples

Remember our Alpaca template from before? Now we need to wrap it in a PyTorch Dataset class.

ALPACA_TEMPLATE = """Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Response:
{response}"""

class SFTDataset(Dataset):
    """Dataset for supervised fine-tuning.
    
    This class does three critical things:
    1. Formats examples with the Alpaca template
    2. Tokenizes text into numbers the model understands
    3. Creates labels that mask the prompt (we only train on responses!)
    """
    
    def __init__(self, data, tokenizer, max_length=512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        """How many examples do we have?"""
        return len(self.data)
    
    def __getitem__(self, idx):
        """Get one training example, fully prepped and ready."""
        item = self.data[idx]
        
        # Step 1: Format with our template
        formatted = ALPACA_TEMPLATE.format(
            instruction=item['instruction'],
            response=item['output']
        )
        
        # Step 2: Figure out where the response starts
        # We need this because we DON'T want to train on the instruction part
        # Only the response should contribute to the loss
        prompt = ALPACA_TEMPLATE.format(
            instruction=item['instruction'],
            response=''  # Empty response to find where it would start
        )
        
        # Step 3: Tokenize the full text
        full_tokens = self.tokenizer(
            formatted,
            max_length=self.max_length,
            truncation=True,  # Cut off if too long
            padding='max_length',  # Pad to consistent length
            return_tensors='pt'
        )
        
        # Step 4: Tokenize just the prompt to find response boundary
        prompt_tokens = self.tokenizer(
            prompt,
            max_length=self.max_length,
            truncation=True,
            return_tensors='pt'
        )
        
        response_start = prompt_tokens['input_ids'].shape[1]
        
        # Step 5: Create labels with masking
        # Here's the key insight: we copy the input_ids but mask the prompt
        # The -100 value tells PyTorch "don't calculate loss for these tokens"
        labels = full_tokens['input_ids'].clone().squeeze(0)
        labels[:response_start] = -100  # Mask the instruction
        
        # Also mask padding tokens (we don't want to train on padding!)
        # Note: since we later set pad_token = eos_token for GPT-2, this also
        # masks the response's final EOS token - an acceptable trade-off here
        labels[labels == self.tokenizer.pad_token_id] = -100
        
        # Return everything the training loop needs
        return {
            'input_ids': full_tokens['input_ids'].squeeze(0),
            'attention_mask': full_tokens['attention_mask'].squeeze(0),
            'labels': labels
        }

# Let's test it with a sample
print("Dataset class defined successfully!")
print("\nKey insight: We mask the instruction with -100 so the model only learns")
print("to predict the response. This is supervised fine-tuning in action.")
Dataset class defined successfully!

Key insight: We mask the instruction with -100 so the model only learns
to predict the response. This is supervised fine-tuning in action.
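
Before moving on, it's worth seeing why -100 works as the mask value. Here's a tiny standalone sketch (fake logits, not our model) showing that PyTorch's cross-entropy defaults to ignore_index=-100, so masked positions simply don't count toward the loss:

import torch
import torch.nn.functional as F

# Fake logits for a 4-token sequence over a GPT-2-sized vocabulary
logits = torch.randn(1, 4, 50257)
# Pretend the first two positions are the prompt and mask them with -100
labels = torch.tensor([[-100, -100, 42, 7]])

# F.cross_entropy defaults to ignore_index=-100 (the same convention the
# Hugging Face loss uses), so the masked positions contribute nothing
loss_all = F.cross_entropy(logits.view(-1, 50257), labels.view(-1))
loss_resp = F.cross_entropy(logits[:, 2:].reshape(-1, 50257), labels[:, 2:].reshape(-1))
print(loss_all.item(), loss_resp.item())  # identical: only the last two tokens matter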

Training Configuration

Before we train, we need to make some decisions. How fast should we learn? How many examples at once? How many times through the data?

These are called hyperparameters.

The settings below are reasonable defaults for our small demo. Probably.

from dataclasses import dataclass

@dataclass
class SFTConfig:
    """Configuration for SFT training.
    
    Let's break down what each setting does:
    """
    
    # Which model to start from
    model_name: str = "gpt2"
    
    # Maximum sequence length (longer = more memory)
    max_length: int = 512
    
    # Batch size: how many examples to process at once
    # Bigger = faster but needs more memory
    batch_size: int = 4
    
    # Learning rate: how big of steps to take when updating weights
    # This is the single most important hyperparameter
    learning_rate: float = 2e-4  # 0.0002 - on the high side for full fine-tuning, but fine for a quick demo
    
    # Epochs: how many times to loop through the entire dataset
    # One epoch = seeing every training example once
    num_epochs: int = 3
    
    # Warmup steps: gradually increase learning rate at the start
    # Prevents the model from making crazy updates early on
    warmup_steps: int = 100
    
    # Gradient accumulation: simulate bigger batches without more memory
    # Process this many batches before updating weights
    # Effective batch size = batch_size * gradient_accumulation_steps
    gradient_accumulation_steps: int = 4
    
    # Gradient clipping: prevent exploding gradients
    # If gradients get bigger than this, scale them down
    max_grad_norm: float = 1.0
    
    # How often to log progress
    logging_steps: int = 10
    
    # How often to check validation loss
    eval_steps: int = 100
    
    # How often to save checkpoints
    save_steps: int = 500
    
    # Where to save the model
    output_dir: str = "./sft_output"

config = SFTConfig()

print("Training configuration:")
print("=" * 60)
for k, v in vars(config).items():
    print(f"  {k:.<35} {v}")
print("=" * 60)
print(f"\nEffective batch size: {config.batch_size * config.gradient_accumulation_steps}")
print("(That's how many examples we process before each weight update)")
Training configuration:
============================================================
  model_name......................... gpt2
  max_length......................... 512
  batch_size......................... 4
  learning_rate...................... 0.0002
  num_epochs......................... 3
  warmup_steps....................... 100
  gradient_accumulation_steps........ 4
  max_grad_norm...................... 1.0
  logging_steps...................... 10
  eval_steps......................... 100
  save_steps......................... 500
  output_dir......................... ./sft_output
============================================================

Effective batch size: 16
(That's how many examples we process before each weight update)
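
Before kicking things off, it helps to sanity-check how many optimizer updates these numbers actually produce, and how that compares to warmup_steps. A quick back-of-the-envelope sketch (the dataset size here is a stand-in, since we load the real data further down):

# Rough estimate of optimizer updates (dataset_size is a placeholder;
# we use a 900-example training split later in this notebook)
dataset_size = 900
batches_per_epoch = dataset_size // config.batch_size
optimizer_steps = batches_per_epoch * config.num_epochs // config.gradient_accumulation_steps
print(f"Batches per epoch:     {batches_per_epoch}")
print(f"Total optimizer steps: {optimizer_steps}")
print(f"Warmup fraction:       {config.warmup_steps / optimizer_steps:.0%}")
# If warmup_steps approaches (or exceeds) the total, the learning rate never
# fully ramps up - something to keep in mind for very short demo runs.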

The Training Loop

This is where we actually teach the model. It has three steps, repeated thousands of times:

  1. Forward pass: Show the model an example, see what it predicts

  2. Backward pass: Calculate how wrong it was, compute gradients

  3. Update: Adjust the weights to be a little less wrong next time

Rinse and repeat until the model gets good (or you run out of patience/GPU credits).

The code below looks long, but it’s just those three steps wrapped in careful bookkeeping.

def train_sft(model, tokenizer, train_dataset, eval_dataset, config):
    """Complete SFT training loop.
    
    This function orchestrates the entire training process.
    """
    
    # Create data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=0
    )
    
    eval_loader = DataLoader(
        eval_dataset,
        batch_size=config.batch_size,
        shuffle=False
    )
    
    # Setup optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=0.01
    )
    
    # Learning rate scheduler
    total_steps = len(train_loader) * config.num_epochs // config.gradient_accumulation_steps
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=config.warmup_steps,
        num_training_steps=total_steps
    )
    
    # Training state
    model.train()
    global_step = 0
    best_eval_loss = float('inf')
    
    for epoch in range(config.num_epochs):
        epoch_loss = 0
        
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.num_epochs}")
        
        for step, batch in enumerate(progress_bar):
            batch = {k: v.to(device) for k, v in batch.items()}
            
            # Forward pass
            outputs = model(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                labels=batch['labels']
            )
            
            loss = outputs.loss / config.gradient_accumulation_steps
            
            # Backward pass
            loss.backward()
            
            epoch_loss += loss.item() * config.gradient_accumulation_steps
            
            # Update weights
            if (step + 1) % config.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(),
                    config.max_grad_norm
                )
                
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                
                global_step += 1
                
                if global_step % config.logging_steps == 0:
                    avg_loss = epoch_loss / (step + 1)
                    progress_bar.set_postfix({
                        'loss': f'{avg_loss:.4f}',
                        'lr': f'{scheduler.get_last_lr()[0]:.2e}'
                    })
        
        # End of epoch - evaluate
        eval_loss = evaluate(model, eval_loader, device)
        print(f"\nEpoch {epoch+1} - Train Loss: {epoch_loss/len(train_loader):.4f}, Eval Loss: {eval_loss:.4f}")
        
        # Save best model
        if eval_loss < best_eval_loss:
            best_eval_loss = eval_loss
            model.save_pretrained(f"{config.output_dir}/best")
            tokenizer.save_pretrained(f"{config.output_dir}/best")
            print(f"Saved best model with eval loss: {eval_loss:.4f}")
    
    return model


def evaluate(model, eval_loader, device):
    """Evaluate model on validation set.
    
    This tells us how well the model generalizes to new data.
    If training loss goes down but eval loss goes up, that's overfitting.
    """
    model.eval()
    total_loss = 0
    
    with torch.no_grad():
        for batch in eval_loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                labels=batch['labels']
            )
            total_loss += outputs.loss.item()
    
    model.train()
    return total_loss / len(eval_loader)

print("Training loop defined.")
print("\nKey concepts:")
print("  • Forward pass: Model makes predictions")
print("  • Backward pass: Calculate gradients")
print("  • Optimizer step: Update the weights")
print("  • Gradient accumulation: Simulate bigger batches")
print("  • Validation: Check if we're learning or just memorizing")
Training loop defined.

Key concepts:
  • Forward pass: Model makes predictions
  • Backward pass: Calculate gradients
  • Optimizer step: Update the weights
  • Gradient accumulation: Simulate bigger batches
  • Validation: Check if we're learning or just memorizing
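
Gradient accumulation is the least intuitive item on that list, so here's a small standalone sketch (a toy linear model, not our transformer) showing that accumulating over 4 micro-batches with the loss divided by 4 gives the same gradients as one big batch:

import torch

torch.manual_seed(0)
w = torch.zeros(3, requires_grad=True)   # toy "model": just 3 weights
x = torch.randn(8, 3)                    # one "big batch" of 8 examples
y = torch.randn(8)

# One big batch
loss_big = ((x @ w - y) ** 2).mean()
loss_big.backward()
grad_big = w.grad.clone()
w.grad = None

# Same data as 4 micro-batches of 2, each loss divided by 4
for chunk_x, chunk_y in zip(x.chunk(4), y.chunk(4)):
    loss = ((chunk_x @ w - chunk_y) ** 2).mean() / 4
    loss.backward()                      # gradients accumulate in w.grad

print(torch.allclose(grad_big, w.grad))  # True - identical gradients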

Setting Up for Training

Alright, let’s load our model and data. This is the pre-flight checklist before takeoff.

We’ll use GPT-2 as our base model (it’s small enough to train quickly but still powerful), and the Alpaca dataset for training examples.

For this demo, we’re using a tiny subset of the data. In the real world, you’d use the full dataset and let it run for hours (days?).

# Load model and tokenizer
model_name = "gpt2"
print(f"Loading {model_name}...")

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# GPT-2 doesn't have a pad token by default, so we'll use the EOS token
# (This is a common trick - pad tokens are just for batching anyway)
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.pad_token_id

# Move model to GPU
model.to(device)
print(f"✓ Model loaded and moved to {device}")

# Load dataset
print("\nLoading Alpaca dataset...")
raw_data = load_dataset("yahma/alpaca-cleaned", split="train")

# Use a small subset for this demo
# (Training on 1000 examples instead of 52,000 - your GPU will thank me)
raw_data = raw_data.select(range(1000))
print(f"✓ Loaded {len(raw_data)} examples")

# Split into train/eval (90/10 split)
train_size = int(0.9 * len(raw_data))
train_data = raw_data.select(range(train_size))
eval_data = raw_data.select(range(train_size, len(raw_data)))

# Create datasets using our SFTDataset class
# We use shorter sequences (256) to save memory
train_dataset = SFTDataset(train_data, tokenizer, max_length=256)
eval_dataset = SFTDataset(eval_data, tokenizer, max_length=256)

print(f"✓ Created datasets:")
print(f"  - Training: {len(train_dataset)} examples")
print(f"  - Validation: {len(eval_dataset)} examples")

# Let's peek at one example to make sure everything looks right
sample = train_dataset[0]
print(f"\nSample training example:")
print(f"  - Input IDs shape: {sample['input_ids'].shape}")
print(f"  - Labels shape: {sample['labels'].shape}")
print(f"  - Number of masked tokens: {(sample['labels'] == -100).sum().item()}")
print(f"  - Number of trainable tokens: {(sample['labels'] != -100).sum().item()}")
Loading gpt2...
✓ Model loaded and moved to cuda

Loading Alpaca dataset...
✓ Loaded 1000 examples
✓ Created datasets:
  - Training: 900 examples
  - Validation: 100 examples

Sample training example:
  - Input IDs shape: torch.Size([256])
  - Labels shape: torch.Size([256])
  - Number of masked tokens: 106
  - Number of trainable tokens: 150
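
A quick way to convince yourself the masking landed in the right place (a sanity-check sketch, not needed for training): decode only the positions whose labels aren't -100 and confirm they cover the response rather than the instruction.

# Decode only the unmasked (trainable) positions of the sample above
trainable_ids = sample['input_ids'][sample['labels'] != -100]
print(tokenizer.decode(trainable_ids))
# Expect to see (roughly) the response text and nothing from the instruction;
# padding positions were masked to -100 as well.
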
# Adjust config for a quick demo run
config.num_epochs = 1
config.max_length = 256

print("Starting training...")
print("=" * 60)

model = train_sft(model, tokenizer, train_dataset, eval_dataset, config)

print("=" * 60)
print("Training complete.")
print("\nWhat happened:")
print("  • The model saw 900 training examples")
print("  • It adjusted its weights to minimize prediction errors")
print("  • We validated on 100 held-out examples to check generalization")
print("  • The best model was saved to disk")
Starting training...
============================================================
`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.

Epoch 1 - Train Loss: 2.5210, Eval Loss: 2.4244
Saved best model with eval loss: 2.4244
============================================================
Training complete.

What happened:
  • The model saw 900 training examples
  • It adjusted its weights to minimize prediction errors
  • We validated on 100 held-out examples to check generalization
  • The best model was saved to disk
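
If you come back to this later (or want to use the model elsewhere), the best checkpoint can be reloaded straight from the output directory. A short sketch, assuming the run above wrote to ./sft_output/best:

import math
from transformers import AutoModelForCausalLM, AutoTokenizer

# Reload the best checkpoint that train_sft() saved
best_dir = f"{config.output_dir}/best"   # "./sft_output/best"
reloaded_tokenizer = AutoTokenizer.from_pretrained(best_dir)
reloaded_model = AutoModelForCausalLM.from_pretrained(best_dir).to(device)
print("Reloaded model from", best_dir)

# The eval loss is an average per-token cross-entropy, so exp(loss) is
# (roughly) perplexity - lower means the model is less "surprised" by the data
print(f"Eval perplexity: ~{math.exp(2.4244):.1f}")  # using the eval loss printed above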

Testing the Fine-Tuned Model

Let’s see if our model can actually follow instructions now.

We’ll give it a prompt and watch it generate a response. If training worked, it should follow the Alpaca format and give helpful answers.

def generate_response(model, tokenizer, instruction, max_new_tokens=100):
    """Generate a response for an instruction."""
    prompt = ALPACA_TEMPLATE.format(instruction=instruction, response='')
    
    inputs = tokenizer(prompt, return_tensors='pt').to(device)
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.pad_token_id
        )
    
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    response = response.split("### Response:")[-1].strip()
    
    return response

# Test with a few different instructions
test_instructions = [
    "Explain what machine learning is in simple terms.",
    "Write a haiku about programming.",
    "What are the three branches of the US government?"
]

print("Testing the fine-tuned model...")
print("=" * 60)

for instruction in test_instructions:
    print(f"\nInstruction: {instruction}")
    print("-" * 60)
    response = generate_response(model, tokenizer, instruction)
    print(f"Response: {response}")
    print("=" * 60)

print("\nThe model should:")
print("  ✓ Follow the instruction format")
print("  ✓ Stay on topic")
print("  ✓ Generate coherent responses")
print("\nRemember: We only trained for one epoch on 900 examples.")
print("With more data and training time, it would get much better.")
Testing the fine-tuned model...
============================================================

Instruction: Explain what machine learning is in simple terms.
------------------------------------------------------------
Response: The main goal of machine learning is to provide a broad range of machine learning algorithms, and to provide a variety of potential applications for machine learning.

Machine learning is a computational process that uses the power of natural language processing to learn and interpret text and images in complex, dynamic, and complex environments. It is a powerful and versatile way to learn and interpret text and images in complex, dynamic, and complex environments.

Machine learning is a complex and challenging field that requires a diverse range of
============================================================

Instruction: Write a haiku about programming.
------------------------------------------------------------
Response: Programming the Future is an ambitious project that uses technology to create a new generation of programmers and their skills. The goal of this project is to create a new generation of programmers by using technology to create software that can be written and used in a wide range of industries.

### Overview:
The goal of this project is to create a new generation of programmers by using technology to create software that can be written and used in a wide range of industries.

The goal of this project is
============================================================

Instruction: What are the three branches of the US government?
------------------------------------------------------------
Response: The three branches of the US government are:

1. The Federal Bureau of Investigation (FBI)
2. The Department of Homeland Security (DHS)
3. The Department of Energy (DOE)
4. The Department of Energy (DOE)
5. The Department of Transportation (DOT)
6. The Environmental Protection Agency (EPA)
7. The Federal Communications Commission (FCC)
8. The Federal Deposit Insurance Corporation (FDIC)
============================================================

The model should:
  ✓ Follow the instruction format
  ✓ Stay on topic
  ✓ Generate coherent responses

Remember: We only trained for one epoch on 900 examples.
With more data and training time, it would get much better.