
Training a Transformer: The Main Event

Alright. Deep breath.

We’ve learned about tokenization, formatting, and all the prep work. Now it’s time to actually train the thing.

We’re going to take a pre-trained model and teach it new tricks through supervised fine-tuning.

The Training Journey (A Roadmap)

Think of training as a road trip. You need:

  1. A vehicle (the pre-trained model and tokenizer)

  2. Fuel (the training data, properly formatted)

  3. A route (the data loader that feeds examples in batches)

  4. Navigation (the optimizer and learning rate scheduler)

  5. The actual driving (the training loop where learning happens)

  6. Rest stops (evaluation checkpoints)

  7. Your destination (the fine-tuned model, saved and ready to use)

Each step matters. Skip one and you’re stranded on the side of the road with a confused transformer.

Let’s build this piece by piece.

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, get_linear_schedule_with_warmup
from datasets import load_dataset
from tqdm.auto import tqdm
import numpy as np

# Setup device - GPU if available, CPU otherwise
# (Training on CPU is like walking to San Francisco. Technically possible, but...)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

if device.type == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
Using device: cuda
GPU: Radeon RX 7900 XTX
Memory: 25.75 GB

The Dataset Class: Preparing Training Examples

Remember our Alpaca template from before? Now we need to wrap it in a PyTorch Dataset class.

Why? Because the training loop doesn’t want to deal with raw data. It wants a nice, standardized interface where it can say “give me example #42” and get back properly formatted tensors.

Think of this as a restaurant kitchen. The chef (training loop) doesn’t want to deal with whole chickens and raw vegetables. They want prepped ingredients, measured and ready to cook.
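The whole contract is two methods. __len__ answers “how many examples?” and __getitem__ answers “give me example #idx”. The plain-Python sketch below shows that interface along with a naive in-order batcher. ToyDataset and batches are hypothetical names, not PyTorch classes. PyTorch’s DataLoader does the real batching, and it also handles shuffling and collating the per-example tensors into batch tensors.

```python
class ToyDataset:
    """Hypothetical stand-in for a map-style Dataset: supports len() and indexing."""

    def __init__(self, records):
        self.records = records

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        # A real SFT dataset would format and tokenize here; we just split on spaces
        return {"tokens": self.records[idx].split()}


def batches(dataset, batch_size):
    """Hypothetical helper: group examples in index order, keeping a short last batch
    (DataLoader's default, drop_last=False, does the same)."""
    for start in range(0, len(dataset), batch_size):
        stop = min(start + batch_size, len(dataset))
        yield [dataset[i] for i in range(start, stop)]


ds = ToyDataset(["a b", "c d e", "f", "g h", "i"])
print(len(ds))                            # 5
print(ds[1])                              # {'tokens': ['c', 'd', 'e']}
print([len(b) for b in batches(ds, 2)])   # [2, 2, 1]
```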

ALPACA_TEMPLATE = """Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Response:
{response}"""

class SFTDataset(Dataset):
    """Dataset for supervised fine-tuning.
    
    This class does three critical things:
    1. Formats examples with the Alpaca template
    2. Tokenizes text into numbers the model understands
    3. Creates labels that mask the prompt (we only train on responses!)
    """
    
    def __init__(self, data, tokenizer, max_length=512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        """How many examples do we have?"""
        return len(self.data)
    
    def __getitem__(self, idx):
        """Get one training example, fully prepped and ready."""
        item = self.data[idx]
        
        # Step 1: Format with our template
        formatted = ALPACA_TEMPLATE.format(
            instruction=item['instruction'],
            response=item['output']
        )
        
        # Step 2: Figure out where the response starts
        # We need this because we DON'T want to train on the instruction part
        # Only the response should contribute to the loss
        prompt = ALPACA_TEMPLATE.format(
            instruction=item['instruction'],
            response=''  # Empty response to find where it would start
        )
        
        # Step 3: Tokenize the full text
        full_tokens = self.tokenizer(
            formatted,
            max_length=self.max_length,
            truncation=True,  # Cut off if too long
            padding='max_length',  # Pad to consistent length
            return_tensors='pt'
        )
        
        # Step 4: Tokenize just the prompt to find response boundary
        prompt_tokens = self.tokenizer(
            prompt,
            max_length=self.max_length,
            truncation=True,
            return_tensors='pt'
        )
        
        response_start = prompt_tokens['input_ids'].shape[1]
        
        # Step 5: Create labels with masking
        # Here's the key insight: we copy the input_ids but mask the prompt
        # The -100 value tells PyTorch "don't calculate loss for these tokens"
        labels = full_tokens['input_ids'].clone().squeeze(0)
        labels[:response_start] = -100  # Mask the instruction
        
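        # Also mask padding positions (we don't want to train on padding!)
        # Use the attention mask rather than the pad token ID: GPT-2 reuses EOS
        # as its pad token, so matching on the ID would also hide any real EOS
        labels[full_tokens['attention_mask'].squeeze(0) == 0] = -100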
        
        # Return everything the training loop needs
        return {
            'input_ids': full_tokens['input_ids'].squeeze(0),
            'attention_mask': full_tokens['attention_mask'].squeeze(0),
            'labels': labels
        }

# Let's test it with a sample
print("Dataset class defined successfully!")
print("\nKey insight: We mask the instruction with -100 so the model only learns")
print("to predict the response. This is supervised fine-tuning in action.")
Dataset class defined successfully!

Key insight: We mask the instruction with -100 so the model only learns
to predict the response. This is supervised fine-tuning in action.
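Here is the same recipe on made-up token IDs, written in plain Python so the positions are easy to see. make_labels and PAD_ID are hypothetical. The prompt step matches __getitem__ exactly. For padding, this sketch uses the attention mask, which in this toy example selects the same positions as matching the pad token ID. Only the two response tokens survive as training targets.

```python
IGNORE_INDEX = -100  # label value the loss skips
PAD_ID = 0           # hypothetical pad token ID for this toy example


def make_labels(input_ids, attention_mask, response_start):
    """Copy input_ids, then set prompt and padding positions to IGNORE_INDEX."""
    labels = list(input_ids)
    for i in range(min(response_start, len(labels))):
        labels[i] = IGNORE_INDEX        # prompt token: no loss
    for i, keep in enumerate(attention_mask):
        if keep == 0:
            labels[i] = IGNORE_INDEX    # padding: no loss
    return labels


# 3 prompt tokens, 2 response tokens, 3 padding tokens (max_length = 8)
input_ids = [11, 12, 13, 21, 22, PAD_ID, PAD_ID, PAD_ID]
attention_mask = [1, 1, 1, 1, 1, 0, 0, 0]
labels = make_labels(input_ids, attention_mask, response_start=3)
print(labels)  # [-100, -100, -100, 21, 22, -100, -100, -100]
```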

Training Configuration: All the Knobs and Dials

Before we train, we need to make some decisions. How fast should we learn? How many examples at once? How many times through the data?

These are called hyperparameters (fancy name for “settings we tune by hand”). They can make or break your training run.

Too fast? The model goes haywire. Too slow? You’re waiting until the heat death of the universe. Too much data at once? Out of memory. Too little? Noisy, unstable training.

The settings below are reasonable defaults for our small demo. In the real world, you’d spend days (weeks?) tuning these.
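The “too fast or too slow” trade-off shows up even on the simplest objective imaginable. The sketch below runs plain gradient descent on f(w) = w**2, whose gradient is 2w, so each update multiplies w by (1 - 2*lr). A small rate creeps toward the minimum at 0, a moderate one gets there quickly, and a rate above 1 overshoots further on every step and diverges. The function and the learning rates are made up for illustration, and real loss surfaces are far messier.

```python
def descend(lr, w=1.0, steps=20):
    """Run plain gradient descent on f(w) = w**2 and return the final w."""
    for _ in range(steps):
        grad = 2 * w          # derivative of w**2
        w = w - lr * grad     # equivalent to w *= (1 - 2 * lr)
    return w


for lr in (0.01, 0.3, 1.1):
    print(f"lr={lr}: w after 20 steps = {descend(lr):.4g}")
```

At lr=0.01, w is still far from 0 after 20 steps. At lr=0.3, w is essentially 0. At lr=1.1, |w| has grown well past its starting value.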

from dataclasses import dataclass

@dataclass
class SFTConfig:
    """Configuration for SFT training.
    
    Let's break down what each setting does:
    """
    
    # Which model to start from
    model_name: str = "gpt2"
    
    # Maximum sequence length (longer = more memory)
    max_length: int = 512
    
    # Batch size: how many examples to process at once
    # Bigger = faster but needs more memory
    batch_size: int = 4
    
    # Learning rate: how big of steps to take when updating weights
    # This is the single most important hyperparameter
    learning_rate: float = 2e-4  # 0.0002; on the high side for full fine-tuning (the Hugging Face Trainer default is 5e-5)
    
    # Epochs: how many times to loop through the entire dataset
    # One epoch = seeing every training example once
    num_epochs: int = 3
    
    # Warmup steps: gradually increase learning rate at the start
    # Prevents the model from making crazy updates early on
    warmup_steps: int = 100
    
    # Gradient accumulation: simulate bigger batches without more memory
    # Process this many batches before updating weights
    # Effective batch size = batch_size * gradient_accumulation_steps
    gradient_accumulation_steps: int = 4
    
    # Gradient clipping: prevent exploding gradients
    # If gradients get bigger than this, scale them down
    max_grad_norm: float = 1.0
    
    # How often to log progress
    logging_steps: int = 10
    
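    # How often to check validation loss (not used by train_sft below,
    # which evaluates once at the end of each epoch)
    eval_steps: int = 100
    
    # How often to save checkpoints (also not used by train_sft below,
    # which saves only when validation loss improves)
    save_steps: int = 500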
    
    # Where to save the model
    output_dir: str = "./sft_output"

config = SFTConfig()

print("Training configuration:")
print("=" * 60)
for k, v in vars(config).items():
    print(f"  {k:.<35} {v}")
print("=" * 60)
print(f"\nEffective batch size: {config.batch_size * config.gradient_accumulation_steps}")
print("(That's how many examples we process before each weight update)")
Training configuration:
============================================================
  model_name......................... gpt2
  max_length......................... 512
  batch_size......................... 4
  learning_rate...................... 0.0002
  num_epochs......................... 3
  warmup_steps....................... 100
  gradient_accumulation_steps........ 4
  max_grad_norm...................... 1.0
  logging_steps...................... 10
  eval_steps......................... 100
  save_steps......................... 500
  output_dir......................... ./sft_output
============================================================

Effective batch size: 16
(That's how many examples we process before each weight update)
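Why does dividing each micro-batch loss by gradient_accumulation_steps reproduce a 16-example batch? Each loss.backward() call adds into a parameter’s .grad rather than overwriting it, and the mean over 16 examples equals the average of four 4-example means. The plain-Python sketch below checks this with a hypothetical one-parameter model pred = w * x, squared error, hand-written gradients, and made-up data.

```python
def mean_grad(w, xs, ys):
    """Gradient of mean squared error for pred = w * x over one batch."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


xs = [float(i) for i in range(1, 17)]   # 16 made-up inputs
ys = [3.0 * x + 1.0 for x in xs]         # made-up targets
w = 0.5
micro_batch, accum_steps = 4, 4          # effective batch size = 16

# One big batch of 16
big = mean_grad(w, xs, ys)

# Four micro-batches of 4. Each loss is divided by accum_steps, which divides its
# gradient too, and the contributions are summed (what repeated backward() does)
accumulated = 0.0
for k in range(accum_steps):
    lo, hi = k * micro_batch, (k + 1) * micro_batch
    accumulated += mean_grad(w, xs[lo:hi], ys[lo:hi]) / accum_steps

print(big, accumulated)  # -484.5 -484.5
```

The match is exact here because every micro-batch has the same number of examples. For language-model SFT, the loss is a mean over unmasked tokens. When micro-batches contain different numbers of response tokens, the accumulated gradient closely approximates the big-batch gradient but does not reproduce it exactly.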

The Training Loop

Okay, here’s the heart of it all. The training loop.

This is where we actually teach the model. It’s a dance with three steps, repeated thousands of times:

  1. Forward pass: Show the model an example, see what it predicts

  2. Backward pass: Calculate how wrong it was, compute gradients

  3. Update: Adjust the weights to be a little less wrong next time

Rinse and repeat until the model gets good (or you run out of patience/GPU credits).

The code below looks long, but it’s just those three steps wrapped in careful bookkeeping. Let me walk you through it.
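Stripped of the bookkeeping, those three steps (plus gradient clipping) fit in a few lines. This toy version is plain Python and makes several simplifications. It fits a hypothetical two-parameter model pred = w * x + b to made-up points from y = 2x + 1. Gradients are derived by hand where PyTorch would use autograd, and plain SGD stands in for the AdamW used by train_sft. Clipping rescales the gradients so their global L2 norm is at most max_grad_norm, which is how clip_grad_norm_ behaves.

```python
import math

# Made-up data from y = 2x + 1
data = [(x, 2.0 * x + 1.0) for x in (0.0, 1.0, 2.0, 3.0)]
w, b = 0.0, 0.0
lr, max_grad_norm = 0.05, 1.0

for step in range(500):
    # Forward pass: predictions and mean squared error
    preds = [w * x + b for x, _ in data]
    loss = sum((p - y) ** 2 for p, (_, y) in zip(preds, data)) / len(data)

    # Backward pass: gradients of the loss with respect to w and b (by hand)
    gw = sum(2 * (p - y) * x for p, (x, y) in zip(preds, data)) / len(data)
    gb = sum(2 * (p - y) for p, (_, y) in zip(preds, data)) / len(data)

    # Gradient clipping: scale down if the global L2 norm exceeds max_grad_norm
    total_norm = math.sqrt(gw ** 2 + gb ** 2)
    clip_coef = max_grad_norm / (total_norm + 1e-6)
    if clip_coef < 1.0:
        gw, gb = gw * clip_coef, gb * clip_coef

    # Update: step against the gradient (plain SGD)
    w -= lr * gw
    b -= lr * gb

print(f"w={w:.3f}, b={b:.3f}, loss={loss:.6f}")
```

The first steps are clipped because the initial gradient norm is about 19. Once the parameters get close to the solution, the updates are ordinary gradient descent and w and b settle near 2 and 1.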

def train_sft(model, tokenizer, train_dataset, eval_dataset, config):
    """Complete SFT training loop.
    
    This function orchestrates the entire training process.
    It's like conducting an orchestra - lots of moving parts, all needing
    to work together in harmony.
    """
    
    # Create data loaders
    # These batch up our data and shuffle it each epoch
    # Think of this as the assembly line that feeds examples to the model
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,  # Randomize order each epoch - helps training
        num_workers=0  # Parallel data loading (0 = use main process)
    )
    
    eval_loader = DataLoader(
        eval_dataset,
        batch_size=config.batch_size,
        shuffle=False  # No need to shuffle validation data
    )
    
    # Setup optimizer
    # The optimizer is what actually updates the model weights
    # AdamW (Adam with decoupled weight decay) is the usual choice for transformer training
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=0.01  # Regularization that helps reduce overfitting
    )
    
    # Learning rate scheduler
    # We don't use a fixed learning rate - we adjust it over time
    # Start with warmup (gradual increase), then linear decay
    total_steps = len(train_loader) * config.num_epochs // config.gradient_accumulation_steps
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=config.warmup_steps,
        num_training_steps=total_steps
    )
    
    # Training state
    model.train()  # Put model in training mode (enables dropout, etc.)
    global_step = 0  # Total number of weight updates
    best_eval_loss = float('inf')  # Track best model
    
    # The main training loop - iterate over epochs
    # An epoch = one complete pass through the training data
    for epoch in range(config.num_epochs):
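        epoch_loss = 0
        # Clear leftover gradients: if the number of batches isn't a multiple of
        # gradient_accumulation_steps, the last few batches never reached optimizer.step()
        optimizer.zero_grad()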
        
        # Progress bar for visual feedback (because watching loss decrease is satisfying)
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.num_epochs}")
        
        # Iterate over batches within this epoch
        for step, batch in enumerate(progress_bar):
            # Move batch to GPU (if available)
            batch = {k: v.to(device) for k, v in batch.items()}
            
            # ===== FORWARD PASS =====
            # Feed the input through the model, get predictions and loss
            outputs = model(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                labels=batch['labels']  # The model computes loss for us!
            )
            
            # Scale loss by accumulation steps
            # Why? Because we're going to accumulate gradients across multiple batches
            loss = outputs.loss / config.gradient_accumulation_steps
            
            # ===== BACKWARD PASS =====
            # Compute gradients - this is where the learning happens
            # PyTorch automagically calculates how to adjust each weight
            loss.backward()
            
            epoch_loss += loss.item() * config.gradient_accumulation_steps
            
            # ===== UPDATE WEIGHTS =====
            # Only update every gradient_accumulation_steps
            # This simulates a larger batch size without using more memory
            if (step + 1) % config.gradient_accumulation_steps == 0:
                # Clip gradients to prevent explosions
                # Sometimes gradients get REALLY big - this caps them
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(),
                    config.max_grad_norm
                )
                
                # Take the optimization step
                optimizer.step()
                scheduler.step()  # Update learning rate
                optimizer.zero_grad()  # Clear gradients before the next accumulation cycle
                
                global_step += 1
                
                # Log progress
                if global_step % config.logging_steps == 0:
                    avg_loss = epoch_loss / (step + 1)
                    progress_bar.set_postfix({
                        'loss': f'{avg_loss:.4f}',
                        'lr': f'{scheduler.get_last_lr()[0]:.2e}'
                    })
        
        # End of epoch - check how we're doing on validation data
        eval_loss = evaluate(model, eval_loader, device)
        print(f"\nEpoch {epoch+1} - Train Loss: {epoch_loss/len(train_loader):.4f}, Eval Loss: {eval_loss:.4f}")
        
        # Save best model
        # We keep the model with the lowest validation loss
        if eval_loss < best_eval_loss:
            best_eval_loss = eval_loss
            model.save_pretrained(f"{config.output_dir}/best")
            tokenizer.save_pretrained(f"{config.output_dir}/best")
            print(f"Saved best model with eval loss: {eval_loss:.4f}")
    
    return model


def evaluate(model, eval_loader, device):
    """Evaluate model on validation set.
    
    This tells us how well the model generalizes to new data.
    If training loss goes down but eval loss goes up... that's overfitting.
    Bad news bears.
    """
    model.eval()  # Put model in evaluation mode (disables dropout, etc.)
    total_loss = 0
    
    # No gradients needed for evaluation - saves memory and time
    with torch.no_grad():
        for batch in eval_loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                labels=batch['labels']
            )
            total_loss += outputs.loss.item()
    
    model.train()  # Put model back in training mode
    return total_loss / len(eval_loader)

print("Training loop defined!")
print("\nKey concepts:")
print("  • Forward pass: Model makes predictions")
print("  • Backward pass: Calculate gradients (how to improve)")
print("  • Optimizer step: Actually update the weights")
print("  • Gradient accumulation: Simulate bigger batches")
print("  • Validation: Check if we're learning or just memorizing")
Training loop defined!

Key concepts:
  • Forward pass: Model makes predictions
  • Backward pass: Calculate gradients (how to improve)
  • Optimizer step: Actually update the weights
  • Gradient accumulation: Simulate bigger batches
  • Validation: Check if we're learning or just memorizing
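When labels is passed, a Hugging Face causal LM computes the loss itself. Conceptually, it shifts by one position, so the logits at position t are scored against the token at position t+1. It then averages cross-entropy over the positions whose label is not -100. The plain-Python sketch below reproduces that idea on tiny made-up logits. causal_lm_loss is a hypothetical helper, not the library’s code, which works on tensors and has extra options.

```python
import math

IGNORE_INDEX = -100


def cross_entropy(logits, target):
    """Negative log of the softmax probability assigned to the target class."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def causal_lm_loss(logits, labels):
    """Shift by one (position t predicts label t+1), skip -100, average the rest."""
    shifted_logits = logits[:-1]   # the last position has no next token to predict
    shifted_labels = labels[1:]    # the first label is never a prediction target
    losses = [cross_entropy(z, y)
              for z, y in zip(shifted_logits, shifted_labels)
              if y != IGNORE_INDEX]
    return sum(losses) / len(losses)


# Vocabulary of 3 tokens, sequence of 4 positions (made-up numbers)
logits = [[2.0, 0.0, 0.0],
          [0.0, 0.0, 2.0],
          [2.0, 0.0, 0.0],
          [1.0, 1.0, 1.0]]
labels = [IGNORE_INDEX, IGNORE_INDEX, 2, 0]   # first two positions are "prompt"

print(round(causal_lm_loss(logits, labels), 4))  # 0.2395
```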

Setting Up for Training

Alright, let’s load our model and data. This is the pre-flight checklist before takeoff.

We’ll use GPT-2 as our base model (the 124M-parameter gpt2 checkpoint is small enough to fine-tune quickly on a single GPU), and the Alpaca dataset for training examples.

For this demo, we’re using a tiny subset of the data. In the real world, you’d use the full dataset and let it run for hours (days?). But we’re trying to learn here, not max out your electricity bill.

# Load model and tokenizer
model_name = "gpt2"
print(f"Loading {model_name}...")

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# GPT-2 doesn't have a pad token by default, so we'll use the EOS token
# (This is a common trick - pad tokens are just for batching anyway)
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.pad_token_id

# Move model to GPU
model.to(device)
print(f"✓ Model loaded and moved to {device}")

# Load dataset
print("\nLoading Alpaca dataset...")
raw_data = load_dataset("yahma/alpaca-cleaned", split="train")

# Use a small subset for this demo
# (Training on 1000 examples instead of 52,000 - your GPU will thank me)
raw_data = raw_data.select(range(1000))
print(f"✓ Loaded {len(raw_data)} examples")

# Split into train/eval (90/10 split)
train_size = int(0.9 * len(raw_data))
train_data = raw_data.select(range(train_size))
eval_data = raw_data.select(range(train_size, len(raw_data)))

# Create datasets using our SFTDataset class
# We use shorter sequences (256) to save memory
train_dataset = SFTDataset(train_data, tokenizer, max_length=256)
eval_dataset = SFTDataset(eval_data, tokenizer, max_length=256)

print(f"✓ Created datasets:")
print(f"  - Training: {len(train_dataset)} examples")
print(f"  - Validation: {len(eval_dataset)} examples")

# Let's peek at one example to make sure everything looks right
sample = train_dataset[0]
print(f"\nSample training example:")
print(f"  - Input IDs shape: {sample['input_ids'].shape}")
print(f"  - Labels shape: {sample['labels'].shape}")
print(f"  - Number of masked tokens: {(sample['labels'] == -100).sum().item()}")
print(f"  - Number of trainable tokens: {(sample['labels'] != -100).sum().item()}")
Loading gpt2...
✓ Model loaded and moved to cuda

Loading Alpaca dataset...
✓ Loaded 1000 examples
✓ Created datasets:
  - Training: 900 examples
  - Validation: 100 examples

Sample training example:
  - Input IDs shape: torch.Size([256])
  - Labels shape: torch.Size([256])
  - Number of masked tokens: 106
  - Number of trainable tokens: 150
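
# Adjust config for a quick demo run
config.num_epochs = 1
# The datasets above were already built with max_length=256; this line only
# keeps the config in sync and does not change the tokenized examples
config.max_length = 256
# Note: 225 batches / 4 accumulation steps = 56 optimizer steps, fewer than
# warmup_steps=100, so the learning rate never reaches its peak in this run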

print("Starting training...")
print("=" * 60)

# This will take a few minutes
# Watch the loss decrease - that's learning in action!
model = train_sft(model, tokenizer, train_dataset, eval_dataset, config)

print("=" * 60)
print("✓ Training complete!")
print("\nWhat just happened?")
print("  • The model saw 900 training examples")
print("  • It adjusted its weights to minimize prediction errors")
print("  • We validated on 100 held-out examples to check generalization")
print("  • The best model was saved to disk")
print("\nNext up: Let's see if it actually learned anything!")
Starting training...
============================================================
`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.

Epoch 1 - Train Loss: 2.5208, Eval Loss: 2.4195
Saved best model with eval loss: 2.4195
============================================================
✓ Training complete!

What just happened?
  • The model saw 900 training examples
  • It adjusted its weights to minimize prediction errors
  • We validated on 100 held-out examples to check generalization
  • The best model was saved to disk

Next up: Let's see if it actually learned anything!
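That run made far fewer optimizer steps than the config might suggest. get_linear_schedule_with_warmup scales the base learning rate by a factor that rises linearly from 0 to 1 over num_warmup_steps, then falls linearly to 0 at num_training_steps. The plain-Python sketch below re-creates that factor (lr_factor is a hypothetical name, written to match the schedule’s described behavior) and plugs in the demo’s numbers: 900 examples, batch_size=4, 4 accumulation steps, one epoch.

```python
import math


def lr_factor(step, warmup_steps, total_steps):
    """Linear warmup from 0 to 1, then linear decay to 0 at total_steps."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


base_lr = 2e-4
batches_per_epoch = math.ceil(900 / 4)        # 900 examples, batch_size=4
total_steps = batches_per_epoch * 1 // 4      # train_sft's formula: 1 epoch, 4 accumulation steps
# Each optimizer step runs while the scheduler's counter is at 0..total_steps-1
peak = max(lr_factor(s, 100, total_steps) for s in range(total_steps))

print(batches_per_epoch, total_steps)                        # 225 56
print(f"highest learning rate used: {base_lr * peak:.2e}")   # 1.10e-04

# For comparison: 5 warmup steps (about 10% of 56) reach the full rate, then decay
print([round(lr_factor(s, 5, total_steps), 2) for s in (0, 5, 30, 55)])  # [0.0, 1.0, 0.51, 0.02]
```

With warmup_steps=100 the schedule is still in its warmup phase when training ends, so the run never used the configured 2e-4.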

Testing the Fine-Tuned Model

The proof is in the pudding. Let’s see if our model can actually follow instructions now.

We’ll give it a prompt and watch it generate a response. If training worked, it should follow the Alpaca format and give helpful answers.

If it starts rambling about random nonsense... well, that’s why we have validation loss. (And why we save the checkpoint with the best validation loss. After a single epoch, that checkpoint is the same model we test below.)
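generate_response below samples with temperature=0.7 and top_p=0.9. The following sketch shows roughly what those two settings do to a next-token distribution. Temperature divides the logits before the softmax. Top-p then keeps the smallest set of highest-probability tokens whose cumulative probability reaches top_p, renormalizes over that set, and samples from it. The example uses plain Python, made-up logits, and hypothetical softmax and nucleus helpers. It is not the transformers implementation.

```python
import math
import random


def softmax(logits, temperature=1.0):
    """Turn logits into probabilities after dividing by the temperature."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def nucleus(probs, top_p):
    """Keep the most likely tokens until their cumulative probability reaches top_p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    mass = sum(probs[i] for i in kept)
    return {i: probs[i] / mass for i in kept}   # renormalized over the kept tokens


logits = [3.0, 2.0, 1.0, 0.0, -1.0]   # made-up scores for a 5-token vocabulary
for temperature in (1.0, 0.7):
    probs = softmax(logits, temperature)
    kept = nucleus(probs, top_p=0.9)
    print(temperature, [round(p, 3) for p in probs], "kept:", sorted(kept))

# Sample one token from the temperature-0.7, top-p-0.9 distribution
rng = random.Random(0)
choice = rng.choices(list(kept), weights=list(kept.values()))[0]
print("sampled token:", choice)
```

Lowering the temperature concentrates probability on the top token. In this example, that shrinks the nucleus from three candidate tokens to two.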

def generate_response(model, tokenizer, instruction, max_new_tokens=100):
    """Generate a response for an instruction.
    
    This is the moment of truth - does the model actually follow instructions?
    """
    # Format the prompt using our Alpaca template
    prompt = ALPACA_TEMPLATE.format(instruction=instruction, response='')
    
    # Tokenize and move to device
    inputs = tokenizer(prompt, return_tensors='pt').to(device)
    
    # Generate!
    with torch.no_grad():  # No gradients needed for inference
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,  # Sampling for more natural text
            temperature=0.7,  # Higher = more random, lower = more conservative
            top_p=0.9,  # Nucleus sampling: sample only from the smallest set of top tokens covering 90% of the probability
            pad_token_id=tokenizer.pad_token_id
        )
    
    # Decode the generated tokens back to text
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    # Extract just the response part (after "### Response:")
    response = response.split("### Response:")[-1].strip()
    
    return response

# Test with a few different instructions
test_instructions = [
    "Explain what machine learning is in simple terms.",
    "Write a haiku about programming.",
    "What are the three branches of the US government?"
]

print("Testing the fine-tuned model...")
print("=" * 60)

for instruction in test_instructions:
    print(f"\nInstruction: {instruction}")
    print("-" * 60)
    response = generate_response(model, tokenizer, instruction)
    print(f"Response: {response}")
    print("=" * 60)

print("\nHow'd we do? The model should:")
print("  ✓ Follow the instruction format")
print("  ✓ Stay on topic")
print("  ✓ Generate coherent (if not always perfect) responses")
print("\nRemember: We only trained for one epoch on 900 examples.")
print("With more data and training time, it would get much better!")
Testing the fine-tuned model...
============================================================

Instruction: Explain what machine learning is in simple terms.
------------------------------------------------------------
Response: Machine learning is a field in which the development of machine learning algorithms can be used to improve and improve their accuracy, efficiency, and performance. Machine learning algorithms are increasingly used by a wide range of industries, including consumer, office, and transportation. Machine learning is a popular tool that is used in the field of data analysis, machine learning, and machine learning.

Machine learning is the process of learning data and creating models, often by combining and analyzing data from a wide range of sources, often
============================================================

Instruction: Write a haiku about programming.
------------------------------------------------------------
Response: Programming is a branch of science and is often associated with the development of computer programs. It is often used to teach or study a specific skill or field, or to find a suitable subject for a particular job.

One of the most common uses of programming is to develop and improve computer programs. Programs are designed to provide a foundation of knowledge and provide tools to improve the computer's ability to perform tasks or perform tasks in a particular way. This means that programs can be implemented in a variety
============================================================

Instruction: What are the three branches of the US government?
------------------------------------------------------------
Response: The three branches of the US government are the United States, the United Kingdom, and Canada. The United States, in particular, is the United States, and the United Kingdom in general.

The United States is a country, not a country, that is governed by the Constitution and laws of the United States. The United Kingdom is a country that is governed by the laws of the United States.

The United Kingdom is a country that is governed by the laws of the United States,
============================================================

How'd we do? The model should:
  ✓ Follow the instruction format
  ✓ Stay on topic
  ✓ Generate coherent (if not always perfect) responses

Remember: We only trained for one epoch on 900 examples.
With more data and training time, it would get much better!
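All three responses above stop mid-sentence, which is consistent with generation running until max_new_tokens=100. One likely contributor is that SFTDataset never appends an end-of-sequence token after the response, so the model never sees a training target that means “stop here”. A common remedy is to append tokenizer.eos_token to the formatted text, but with GPT-2 this needs care. This notebook reuses EOS as the pad token, so masking labels with labels == pad_token_id would also hide the real EOS, while masking with attention_mask == 0 keeps it. The plain-Python sketch below shows the difference on made-up IDs (50256 is GPT-2’s <|endoftext|> ID; everything else is invented). It only illustrates the idea and is not a rerun of this notebook. The recorded results above were produced without it.

```python
EOS_ID = 50256          # GPT-2's <|endoftext|> token ID
PAD_ID = EOS_ID         # the notebook sets tokenizer.pad_token = tokenizer.eos_token
IGNORE_INDEX = -100

prompt_ids = [101, 102, 103]   # made-up prompt tokens
response_ids = [201, 202]      # made-up response tokens
max_length = 8

# Append EOS after the response, then pad to max_length
ids = prompt_ids + response_ids + [EOS_ID]
attention_mask = [1] * len(ids) + [0] * (max_length - len(ids))
ids = ids + [PAD_ID] * (max_length - len(ids))
response_start = len(prompt_ids)


def mask_prompt(labels):
    """Set every position before response_start to IGNORE_INDEX."""
    return [IGNORE_INDEX if i < response_start else t for i, t in enumerate(labels)]


# Masking padding by token ID also hides the real EOS
by_id = [IGNORE_INDEX if t == PAD_ID else t for t in mask_prompt(ids)]
# Masking padding by attention mask keeps EOS as a training target
by_mask = [IGNORE_INDEX if m == 0 else t for t, m in zip(mask_prompt(ids), attention_mask)]

print(by_id)    # [-100, -100, -100, 201, 202, -100, -100, -100]
print(by_mask)  # [-100, -100, -100, 201, 202, 50256, -100, -100]
```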

What We Learned

Congratulations! You just trained a transformer from scratch (well, fine-tuned one, but that still counts).

Let’s recap the key concepts:

The Training Loop is three steps repeated thousands of times:

  • Forward pass: Show the model data, get predictions

  • Backward pass: Calculate gradients (how wrong were we?)

  • Optimizer step: Update weights to be less wrong

Epochs are complete passes through the training data. More epochs = more learning, but also more risk of overfitting.

Gradient Accumulation lets us simulate bigger batch sizes without running out of memory. We accumulate gradients over multiple small batches, then update.

Learning Rate Scheduling adjusts how aggressively we update weights. Start with warmup (ease into it), then gradually decrease (make smaller adjustments as we get closer to optimal).

Validation Loss tells us if we’re actually learning or just memorizing. If it starts going up while training loss goes down, that’s overfitting.

Label Masking (the -100 trick) means we only train on the parts we care about (the responses), not the prompts.
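The validation-loss rule of thumb above can be written as a small check over per-epoch losses. overfitting_epochs is a hypothetical helper, and the loss values are invented, not results from this notebook.

```python
def overfitting_epochs(train_losses, eval_losses):
    """Return 1-based epochs where train loss fell but eval loss rose vs. the previous epoch."""
    flagged = []
    for e in range(1, len(train_losses)):
        train_down = train_losses[e] < train_losses[e - 1]
        eval_up = eval_losses[e] > eval_losses[e - 1]
        if train_down and eval_up:
            flagged.append(e + 1)
    return flagged


# Made-up loss curves for a 5-epoch run
train = [3.0, 2.6, 2.3, 2.1, 1.9]
evals = [2.9, 2.7, 2.6, 2.65, 2.8]
print(overfitting_epochs(train, evals))  # [4, 5]
```

Under train_sft’s rule, the checkpoint saved under best would come from epoch 3 here, the epoch with the lowest validation loss.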


Next up: LoRA - a clever way to fine-tune with way fewer parameters. Because not everyone has a server farm in their basement.