The Complete Model

Putting It All Together

We’ve built every piece. Now let’s assemble them into a complete, working language model.

Data Flow Through the Model

Token IDs: [101, 2054, 2003, 2115, 2171]
                    ↓
┌─────────────────────────────────────┐
│       Token Embedding               │  Look up vectors
│       + Position Embedding          │  Add position info
│       + Dropout                     │
└─────────────────────────────────────┘
                    ↓
┌─────────────────────────────────────┐
│     Transformer Block × N           │  The core processing
│                                     │
│  (attention + FFN, with residuals)  │
└─────────────────────────────────────┘
                    ↓
┌─────────────────────────────────────┐
│       Final LayerNorm               │  Stabilize outputs
└─────────────────────────────────────┘
                    ↓
┌─────────────────────────────────────┐
│       Output Projection             │  d_model → vocab_size
└─────────────────────────────────────┘
                    ↓
Logits: [0.1, 0.3, 8.2, ...]           Scores for each token

Five stages:

  1. Embedding: Convert token IDs to vectors, add position information

  2. Transformer Blocks: N layers of attention + FFN (we use 4; GPT-3 uses 96)

  3. Final LayerNorm: One last normalization for stability

  4. Output Projection: Map from $d_{model}$ to vocabulary size

  5. Output Logits: Raw scores for each possible next token

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# Components from previous notebooks

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
        batch_size, seq_len, _ = x.shape
        
        # Project, then split into heads: (batch, num_heads, seq_len, d_k)
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        
        # Scaled dot-product attention; masked positions are set to -inf before softmax
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask, float('-inf'))
        
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        
        # Weighted sum of values, then merge heads back to (batch, seq_len, d_model)
        context = torch.matmul(attn, V)
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.W_o(context)


class FeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.linear2(F.gelu(self.linear1(x))))


class TransformerBlock(nn.Module):
    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.attn_norm = nn.LayerNorm(d_model)
        self.attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.ffn_norm = nn.LayerNorm(d_model)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        x = x + self.dropout(self.attention(self.attn_norm(x), mask))
        x = x + self.dropout(self.ffn(self.ffn_norm(x)))
        return x

The Complete GPT Model

class GPT(nn.Module):
    """
    A complete decoder-only transformer language model.
    
    This is the same architecture used by GPT-2, GPT-3, and Claude,
    just smaller. The fundamentals are identical.
    """
    
    def __init__(
        self,
        vocab_size: int,
        d_model: int = 256,
        num_heads: int = 4,
        num_layers: int = 4,
        d_ff: int = 1024,
        max_seq_len: int = 512,
        dropout: float = 0.1
    ):
        super().__init__()
        
        self.d_model = d_model
        self.vocab_size = vocab_size
        
        # Token and position embeddings
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(max_seq_len, d_model)
        self.embed_dropout = nn.Dropout(dropout)
        
        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        
        # Final layer norm
        self.ln_f = nn.LayerNorm(d_model)
        
        # Output projection: d_model → vocab_size
        self.output_proj = nn.Linear(d_model, vocab_size, bias=False)
        
        # Weight tying: share weights between embedding and output
        # This is a common technique that reduces parameters and improves quality
        self.output_proj.weight = self.token_embedding.weight
        
        # Initialize weights
        self._init_weights()
    
    def _init_weights(self):
        """Initialize weights with small random values."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    torch.nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
    
    def forward(
        self, 
        token_ids: torch.Tensor,
        mask: torch.Tensor = None
    ) -> torch.Tensor:
        """
        Args:
            token_ids: (batch, seq_len) - integer token IDs
            mask: optional causal mask
        
        Returns:
            logits: (batch, seq_len, vocab_size) - scores for each token
        """
        batch_size, seq_len = token_ids.shape
        device = token_ids.device
        
        # Create causal mask if not provided
        if mask is None:
            mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1).bool()
        
        # 1. Embeddings
        positions = torch.arange(seq_len, device=device)
        x = self.token_embedding(token_ids) + self.position_embedding(positions)
        x = self.embed_dropout(x)
        
        # 2. Transformer blocks
        for block in self.blocks:
            x = block(x, mask)
        
        # 3. Final layer norm
        x = self.ln_f(x)
        
        # 4. Project to vocabulary
        logits = self.output_proj(x)
        
        return logits
# Create the model
model = GPT(
    vocab_size=10000,
    d_model=256,
    num_heads=4,
    num_layers=4,
    d_ff=1024,
    max_seq_len=512,
    dropout=0.1
)

print("GPT Model Configuration")
print("=" * 40)
print(f"  vocab_size:   {model.vocab_size:,}")
print(f"  d_model:      {model.d_model}")
print(f"  num_heads:    4")
print(f"  num_layers:   4")
print(f"  d_ff:         1024 (4 × d_model)")
print(f"  max_seq_len:  512")
GPT Model Configuration
========================================
  vocab_size:   10,000
  d_model:      256
  num_heads:    4
  num_layers:   4
  d_ff:         1024 (4 × d_model)
  max_seq_len:  512

Counting Every Parameter

Let’s count exactly where all the parameters come from:

def count_params(module):
    return sum(p.numel() for p in module.parameters())

# Manual calculation
vocab_size = 10000
d_model = 256
num_heads = 4
num_layers = 4
d_ff = 1024
max_seq_len = 512

print("Parameter Breakdown")
print("=" * 60)

# Embeddings
token_embed = vocab_size * d_model
pos_embed = max_seq_len * d_model
print(f"Token embeddings:      {vocab_size} × {d_model} = {token_embed:>12,}")
print(f"Position embeddings:   {max_seq_len} × {d_model} = {pos_embed:>12,}")

# Per block
attn_params = 4 * d_model * d_model  # Q, K, V, O
ffn_params = 2 * d_model * d_ff + d_ff + d_model  # up + down + biases
ln_params = 4 * d_model  # 2 layer norms × (gamma + beta)
block_params = attn_params + ffn_params + ln_params

print(f"\nPer transformer block:")
print(f"  Attention (Q,K,V,O):  4 × {d_model}² = {attn_params:>12,}")
print(f"  FFN:                  {ffn_params:>12,}")
print(f"  LayerNorm (×2):       {ln_params:>12,}")
print(f"  Block total:          {block_params:>12,}")

# All blocks
all_blocks = num_layers * block_params
print(f"\nAll {num_layers} blocks:          {all_blocks:>12,}")

# Final LN
final_ln = 2 * d_model
print(f"Final LayerNorm:        {final_ln:>12,}")

# Output projection (tied with token embedding - no additional params)
print(f"Output projection:      (tied with token embeddings)")

# Total
total_calculated = token_embed + pos_embed + all_blocks + final_ln
total_actual = count_params(model)

print("=" * 60)
print(f"Total (calculated):     {total_calculated:>12,}")
print(f"Total (actual):         {total_actual:>12,}")
Parameter Breakdown
============================================================
Token embeddings:      10000 × 256 =    2,560,000
Position embeddings:   512 × 256 =      131,072

Per transformer block:
  Attention (Q,K,V,O):  4 × 256² =      262,144
  FFN:                       525,568
  LayerNorm (×2):              1,024
  Block total:               788,736

All 4 blocks:             3,154,944
Final LayerNorm:                 512
Output projection:      (tied with token embeddings)
============================================================
Total (calculated):        5,846,528
Total (actual):            5,846,528
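
Because the output projection is tied to the token embedding, it contributes no parameters of its own, which is why the calculated and actual totals match. A quick sanity check on the model defined above confirms the two modules share a single tensor:

# Weight tying check: the output projection reuses the embedding matrix,
# so model.parameters() counts this tensor only once.
print(model.output_proj.weight is model.token_embedding.weight)   # True
print(tuple(model.output_proj.weight.shape))                      # (10000, 256)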

Testing the Forward Pass

# Generate random token IDs (simulating tokenized text)
batch_size = 2
seq_len = 16
tokens = torch.randint(0, vocab_size, (batch_size, seq_len))

# Forward pass
logits = model(tokens)

print(f"Input tokens shape:  {tokens.shape}")
print(f"Output logits shape: {logits.shape}")
print(f"\nFor each position, we get {vocab_size} scores (one per vocabulary token).")
Input tokens shape:  torch.Size([2, 16])
Output logits shape: torch.Size([2, 16, 10000])

For each position, we get 10000 scores (one per vocabulary token).
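
The forward pass above built its own causal mask because we did not pass one in. As a small illustration of what that torch.triu call produces, here is the mask for a length-5 sequence (True marks the future positions that attention is not allowed to see):

# The same mask construction used inside GPT.forward, for seq_len = 5.
# Row i has True in columns j > i: position i cannot attend to later tokens.
demo_mask = torch.triu(torch.ones(5, 5), diagonal=1).bool()
print(demo_mask)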

What Are Logits?

The model outputs logits: raw, unnormalized scores for each token in the vocabulary. Higher scores mean the model thinks that token is more likely to come next.

To convert logits to probabilities, apply softmax:

$$P(\text{token}_i) = \frac{e^{\text{logit}_i}}{\sum_j e^{\text{logit}_j}}$$

Then you can either:

  • Greedy decoding: Pick the highest-probability token

  • Sampling: Randomly sample from the distribution (for variety)

  • Top-k/Top-p sampling: Sample from the most likely tokens

# Convert logits to probabilities
last_position_logits = logits[0, -1, :]  # Last position of first batch
probs = F.softmax(last_position_logits, dim=-1)

# Top 5 predictions
top_probs, top_indices = torch.topk(probs, 5)

print("Top 5 predicted next tokens (before training):")
for prob, idx in zip(top_probs, top_indices):
    print(f"  Token {idx.item():>5}: {prob.item():.4f} probability")

print("\nThese are random because the model is untrained!")
print(f"With 10,000 tokens, uniform probability = {1/10000:.5f}")
Top 5 predicted next tokens (before training):
  Token   339: 0.0003 probability
  Token  7951: 0.0003 probability
  Token  8281: 0.0003 probability
  Token  2075: 0.0003 probability
  Token  1483: 0.0003 probability

These are random because the model is untrained!
With 10,000 tokens, uniform probability = 0.00010
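
As a minimal sketch of the three decoding strategies listed above, reusing last_position_logits and probs from the previous cell (the k = 50 cutoff is just an illustrative choice; top-p works analogously but truncates by cumulative probability instead of by count):

# Greedy decoding: take the single highest-scoring token
greedy_token = torch.argmax(last_position_logits).item()

# Sampling: draw one token from the full softmax distribution
sampled_token = torch.multinomial(probs, num_samples=1).item()

# Top-k sampling: keep the k most likely tokens, renormalize, then sample
k = 50
topk_probs, topk_indices = torch.topk(probs, k)
topk_probs = topk_probs / topk_probs.sum()
topk_token = topk_indices[torch.multinomial(topk_probs, 1)].item()

print(f"Greedy: {greedy_token}, sampled: {sampled_token}, top-k: {topk_token}")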

A Single Training Step

Let’s see what one training step looks like. We’ll compute the loss and update the weights.

# Setup
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Training uses next-token prediction:
# Input:  [token0, token1, token2, token3, ...]
# Target: [token1, token2, token3, token4, ...]  (shifted by 1)

# We predict each token from the preceding tokens
input_tokens = tokens[:, :-1]   # All but last
target_tokens = tokens[:, 1:]   # All but first

print(f"Input:  {input_tokens.shape}")
print(f"Target: {target_tokens.shape}")
Input:  torch.Size([2, 15])
Target: torch.Size([2, 15])
# Forward pass
logits = model(input_tokens)

# Compute cross-entropy loss
# Reshape for cross_entropy: (batch × seq_len, vocab_size) vs (batch × seq_len)
loss = F.cross_entropy(
    logits.view(-1, vocab_size),
    target_tokens.reshape(-1)
)

print(f"Loss: {loss.item():.4f}")
print(f"\nExpected loss for random predictions: {math.log(vocab_size):.2f}")
print(f"(Cross-entropy of uniform distribution over {vocab_size} tokens)")
Loss: 9.2700

Expected loss for random predictions: 9.21
(Cross-entropy of uniform distribution over 10000 tokens)
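
That 9.21 figure follows directly: an untrained model spreads probability roughly uniformly, so the correct token gets probability about $1/V$ and the cross-entropy is

$$-\log\frac{1}{V} = \log V = \log 10000 \approx 9.21$$

(natural log, matching math.log above). Our measured 9.27 is slightly higher because the random initial logits are not perfectly uniform.
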
# Backward pass
optimizer.zero_grad()
loss.backward()

# Check gradient magnitudes
total_grad_norm = 0
for p in model.parameters():
    if p.grad is not None:
        total_grad_norm += p.grad.norm().item() ** 2
total_grad_norm = total_grad_norm ** 0.5

print(f"Total gradient norm: {total_grad_norm:.4f}")
Total gradient norm: 8.3218
# Update weights
optimizer.step()

print("Completed one training step!")
print("\nTo train a language model:")
print("  - Repeat this millions of times")
print("  - On real text data")
print("  - With techniques like gradient accumulation")
Completed one training step!

To train a language model:
  - Repeat this millions of times
  - On real text data
  - With techniques like gradient accumulation
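
As a rough sketch of what that repetition looks like in code, assuming a hypothetical data_loader that yields (input_tokens, target_tokens) batches shaped like the ones above (real training adds the techniques covered in the next notebook):

# Minimal training-loop sketch; `data_loader` is a placeholder for a real
# dataset of tokenized text, batched into (input_tokens, target_tokens) pairs.
num_epochs = 3  # illustrative only

for epoch in range(num_epochs):
    for input_tokens, target_tokens in data_loader:
        logits = model(input_tokens)
        loss = F.cross_entropy(
            logits.view(-1, model.vocab_size),
            target_tokens.reshape(-1)
        )
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()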

Model Scale Comparison

Our model is tiny. Here’s how it compares:

| Model | Layers | $d_{model}$ | Heads | Parameters |
|---|---|---|---|---|
| Ours | 4 | 256 | 4 | ~5M |
| GPT-2 Small | 12 | 768 | 12 | 117M |
| GPT-2 Large | 36 | 1280 | 20 | 774M |
| GPT-3 | 96 | 12288 | 96 | 175B |
| Claude/GPT-4 | ? | ? | ? | ~1-2T |

But the architecture is identical. The same attention, the same FFN, the same residuals. GPT-3 is just our model with bigger matrices and more layers.
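
To make that concrete, the GPT class above can be instantiated at roughly GPT-2 Small scale just by changing its constructor arguments. With the hyperparameters from the table, this comes out to about 124M parameters in our implementation, close to the 117M figure above (published counts vary with exactly what is included):

# The same class, scaled to roughly GPT-2 Small (see table above).
gpt2_small_like = GPT(
    vocab_size=50257,    # GPT-2's BPE vocabulary size
    d_model=768,
    num_heads=12,
    num_layers=12,
    d_ff=3072,           # 4 × d_model
    max_seq_len=1024,    # GPT-2's context length
    dropout=0.1
)
print(f"{count_params(gpt2_small_like):,} parameters")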

Key Takeaways

  1. Five stages: Embed → Blocks → LayerNorm → Project → Logits

  2. Weight tying shares parameters between embedding and output projection

  3. Logits are raw scores; apply softmax for probabilities

  4. Training predicts each token from preceding tokens (shifted targets)

  5. Architecture scales: Same fundamentals work from 5M to 1T parameters

Next: Training at Scale

We have a complete model. But training it well requires techniques beyond basic gradient descent:

  • Gradient accumulation: Simulate larger batches without more memory

  • Validation: Detect overfitting before it’s too late

  • Learning rate schedules: Warm up, then decay

In the next notebook, we’ll cover practical training strategies.