The Complete Model

Putting It All Together

We’ve built every piece. Now let’s assemble them into a complete, working language model.

We’re building a decoder-only transformer—the architecture used by GPT, Claude, and LLaMA. “Decoder-only” means it generates text autoregressively: one token at a time, each prediction conditioned on all previous tokens.

The original transformer paper had both an encoder (for reading input) and decoder (for generating output), designed for translation. Researchers discovered that the decoder alone works beautifully for language modeling, and it’s simpler. The key feature is causal masking—each position can only see previous positions, not future ones.
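To make causal masking concrete, here is a small sketch. It uses a toy `seq_len = 4` and gives every query/key pair the same raw score, purely for illustration. Masked positions are filled with `-inf` before the softmax, so they receive exactly zero attention weight:

```python
import torch
import torch.nn.functional as F

seq_len = 4  # toy length, chosen for illustration
# True above the diagonal marks "future" positions a token may not attend to
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()

# Pretend every query/key pair gets the same raw score, then hide the future
scores = torch.zeros(seq_len, seq_len).masked_fill(mask, float("-inf"))
weights = F.softmax(scores, dim=-1)
print(weights)
```

Row 0 can only attend to itself. Row 3 spreads its attention evenly over all four positions. Every row still sums to 1.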

Data Flow Through the Model

Token IDs: [101, 2054, 2003, 2115, 2171]
                    ↓
┌─────────────────────────────────────┐
│       Token Embedding               │  Look up vectors
│       + Position Embedding          │  Add position info
│       + Dropout                     │
└─────────────────────────────────────┘
                    ↓
┌─────────────────────────────────────┐
│     Transformer Block × N           │  The core processing
│                                     │
│  (attention + FFN, with residuals)  │
└─────────────────────────────────────┘
                    ↓
┌─────────────────────────────────────┐
│       Final LayerNorm               │  Stabilize outputs
└─────────────────────────────────────┘
                    ↓
┌─────────────────────────────────────┐
│       Output Projection             │  d_model → vocab_size
└─────────────────────────────────────┘
                    ↓
Logits: [0.1, 0.3, 8.2, ...]           Scores for each token

Five stages:

  1. Embedding: Convert token IDs to vectors, add position information

  2. Transformer Blocks: N layers of attention + FFN (we use 4; GPT-3 uses 96)

  3. Final LayerNorm: One last normalization for stability

  4. Output Projection: Map from $d_{model}$ to vocabulary size

  5. Output Logits: Raw scores for each possible next token
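The embedding stage amounts to two table lookups and an addition. The sketch below uses toy sizes; the values of `vocab_size`, `d_model`, and `max_seq_len` are assumptions for illustration. The position embeddings have shape `(seq_len, d_model)` and broadcast across the batch dimension:

```python
import torch
import torch.nn as nn

vocab_size, d_model, max_seq_len = 100, 8, 32   # toy sizes (assumed)
token_embedding = nn.Embedding(vocab_size, d_model)
position_embedding = nn.Embedding(max_seq_len, d_model)

token_ids = torch.tensor([[5, 17, 42, 7], [1, 2, 3, 4]])  # (batch=2, seq_len=4)
positions = torch.arange(token_ids.shape[1])              # [0, 1, 2, 3]

tok = token_embedding(token_ids)      # torch.Size([2, 4, 8])
pos = position_embedding(positions)   # torch.Size([4, 8]), same for every sequence
x = tok + pos                         # broadcasting gives torch.Size([2, 4, 8])
print(tok.shape, pos.shape, x.shape)
```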

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# Components from previous notebooks

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
        batch_size, seq_len, _ = x.shape
        
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask, float('-inf'))
        
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        
        context = torch.matmul(attn, V)
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.W_o(context)


class FeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.linear2(F.gelu(self.linear1(x))))


class TransformerBlock(nn.Module):
    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.attn_norm = nn.LayerNorm(d_model)
        self.attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.ffn_norm = nn.LayerNorm(d_model)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        x = x + self.dropout(self.attention(self.attn_norm(x), mask))
        x = x + self.dropout(self.ffn(self.ffn_norm(x)))
        return x

The Complete GPT Model

class GPT(nn.Module):
    """
    A complete decoder-only transformer language model.
    
    This follows the same decoder-only design as GPT-2 and GPT-3,
    just much smaller. The core building blocks are the same.
    """
    
    def __init__(
        self,
        vocab_size: int,
        d_model: int = 256,
        num_heads: int = 4,
        num_layers: int = 4,
        d_ff: int = 1024,
        max_seq_len: int = 512,
        dropout: float = 0.1
    ):
        super().__init__()
        
        self.d_model = d_model
        self.vocab_size = vocab_size
        
        # Token and position embeddings
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(max_seq_len, d_model)
        self.embed_dropout = nn.Dropout(dropout)
        
        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        
        # Final layer norm
        self.ln_f = nn.LayerNorm(d_model)
        
        # Output projection: d_model → vocab_size
        self.output_proj = nn.Linear(d_model, vocab_size, bias=False)
        
        # Weight tying: share weights between embedding and output
        # Common technique: saves vocab_size × d_model parameters and often improves quality
        self.output_proj.weight = self.token_embedding.weight
        
        # Initialize weights
        self._init_weights()
    
    def _init_weights(self):
        """Initialize weights with small random values."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    torch.nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
    
    def forward(
        self, 
        token_ids: torch.Tensor,
        mask: torch.Tensor = None
    ) -> torch.Tensor:
        """
        Args:
            token_ids: (batch, seq_len) - integer token IDs
            mask: optional boolean mask of shape (seq_len, seq_len); True marks
                  positions to hide. Defaults to a causal mask.
        
        Returns:
            logits: (batch, seq_len, vocab_size) - scores for each token
        """
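        batch_size, seq_len = token_ids.shape
        device = token_ids.device
        # Learned position embeddings only cover max_seq_len positions
        assert seq_len <= self.position_embedding.num_embeddings, "sequence longer than max_seq_len"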
        
        # Create causal mask if not provided
        if mask is None:
            mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1).bool()
        
        # 1. Embeddings
        positions = torch.arange(seq_len, device=device)
        x = self.token_embedding(token_ids) + self.position_embedding(positions)
        x = self.embed_dropout(x)
        
        # 2. Transformer blocks
        for block in self.blocks:
            x = block(x, mask)
        
        # 3. Final layer norm
        x = self.ln_f(x)
        
        # 4. Project to vocabulary
        logits = self.output_proj(x)
        
        return logits
# Create the model
model = GPT(
    vocab_size=10000,
    d_model=256,
    num_heads=4,
    num_layers=4,
    d_ff=1024,
    max_seq_len=512,
    dropout=0.1
)

print("GPT Model Configuration")
print("=" * 40)
print(f"  vocab_size:   {model.vocab_size:,}")
print(f"  d_model:      {model.d_model}")
print(f"  num_heads:    4")
print(f"  num_layers:   4")
print(f"  d_ff:         1024 (4 × d_model)")
print(f"  max_seq_len:  512")
GPT Model Configuration
========================================
  vocab_size:   10,000
  d_model:      256
  num_heads:    4
  num_layers:   4
  d_ff:         1024 (4 × d_model)
  max_seq_len:  512
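The weight-tying line in `GPT.__init__` works because `nn.Embedding(vocab_size, d_model)` and `nn.Linear(d_model, vocab_size)` both store a `(vocab_size, d_model)` weight matrix. Here is a standalone sketch of the effect on the parameter count, using this notebook's sizes. `Module.parameters()` yields a shared `Parameter` only once:

```python
import torch.nn as nn

vocab_size, d_model = 10000, 256
token_embedding = nn.Embedding(vocab_size, d_model)
output_proj = nn.Linear(d_model, vocab_size, bias=False)
pair = nn.ModuleList([token_embedding, output_proj])

def n_params(module):
    # parameters() yields each Parameter object once, even if two submodules share it
    return sum(p.numel() for p in module.parameters())

untied = n_params(pair)
# Embedding weight: (vocab_size, d_model); Linear weight: (out_features, in_features),
# which is also (vocab_size, d_model), so one matrix can serve both roles
output_proj.weight = token_embedding.weight
tied = n_params(pair)
print(f"untied: {untied:,}  tied: {tied:,}")  # untied: 5,120,000  tied: 2,560,000
```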

Counting Every Parameter

Let’s count exactly where all the parameters come from:

def count_params(module):
    return sum(p.numel() for p in module.parameters())

# Manual calculation
vocab_size = 10000
d_model = 256
num_heads = 4
num_layers = 4
d_ff = 1024
max_seq_len = 512

print("Parameter Breakdown")
print("=" * 60)

# Embeddings
token_embed = vocab_size * d_model
pos_embed = max_seq_len * d_model
print(f"Token embeddings:      {vocab_size} × {d_model} = {token_embed:>12,}")
print(f"Position embeddings:   {max_seq_len} × {d_model} = {pos_embed:>12,}")

# Per block
attn_params = 4 * d_model * d_model  # Q, K, V, O
ffn_params = 2 * d_model * d_ff + d_ff + d_model  # up + down + biases
ln_params = 4 * d_model  # 2 layer norms × (gamma + beta)
block_params = attn_params + ffn_params + ln_params

print(f"\nPer transformer block:")
print(f"  Attention (Q,K,V,O):  4 × {d_model}² = {attn_params:>12,}")
print(f"  FFN:                  {ffn_params:>12,}")
print(f"  LayerNorm (×2):       {ln_params:>12,}")
print(f"  Block total:          {block_params:>12,}")

# All blocks
all_blocks = num_layers * block_params
print(f"\nAll {num_layers} blocks:          {all_blocks:>12,}")

# Final LN
final_ln = 2 * d_model
print(f"Final LayerNorm:        {final_ln:>12,}")

# Output projection (tied with token embedding - no additional params)
print(f"Output projection:      (tied with token embeddings)")

# Total
total_calculated = token_embed + pos_embed + all_blocks + final_ln
total_actual = count_params(model)

print("=" * 60)
print(f"Total (calculated):     {total_calculated:>12,}")
print(f"Total (actual):         {total_actual:>12,}")
Parameter Breakdown
============================================================
Token embeddings:      10000 × 256 =    2,560,000
Position embeddings:   512 × 256 =      131,072

Per transformer block:
  Attention (Q,K,V,O):  4 × 256² =      262,144
  FFN:                       525,568
  LayerNorm (×2):              1,024
  Block total:               788,736

All 4 blocks:             3,154,944
Final LayerNorm:                 512
Output projection:      (tied with token embeddings)
============================================================
Total (calculated):        5,846,528
Total (actual):            5,846,528
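The breakdown above generalizes into a closed-form count. `gpt_param_count` is a hypothetical helper that applies only to this notebook's design: bias-free attention projections, an FFN with biases, and learned position embeddings. Models such as GPT-2 also use attention biases, so it will not reproduce their published counts exactly:

```python
def gpt_param_count(vocab_size, d_model, num_layers, d_ff, max_seq_len, tied=True):
    """Parameter count for the GPT class in this notebook (hypothetical helper)."""
    embeddings = vocab_size * d_model + max_seq_len * d_model
    attention = 4 * d_model * d_model            # W_q, W_k, W_v, W_o, no biases
    ffn = 2 * d_model * d_ff + d_ff + d_model    # two Linear layers with biases
    layer_norms = 2 * (2 * d_model)              # two LayerNorms, gamma + beta each
    blocks = num_layers * (attention + ffn + layer_norms)
    final_ln = 2 * d_model
    output_proj = 0 if tied else vocab_size * d_model
    return embeddings + blocks + final_ln + output_proj

print(f"{gpt_param_count(10000, 256, 4, 1024, 512):,}")              # 5,846,528
print(f"{gpt_param_count(10000, 256, 4, 1024, 512, tied=False):,}")  # 8,406,528
```

Without weight tying, the output projection would add another 2,560,000 parameters, about 44% more than the tied model.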

Testing the Forward Pass

# Generate random token IDs (simulating tokenized text)
batch_size = 2
seq_len = 16
tokens = torch.randint(0, vocab_size, (batch_size, seq_len))

# Forward pass
logits = model(tokens)

print(f"Input tokens shape:  {tokens.shape}")
print(f"Output logits shape: {logits.shape}")
print(f"\nFor each position, we get {vocab_size} scores (one per vocabulary token).")
Input tokens shape:  torch.Size([2, 16])
Output logits shape: torch.Size([2, 16, 10000])

For each position, we get 10000 scores (one per vocabulary token).
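A useful sanity check on any causal model is that changing a later token must not change the logits at earlier positions. The sketch below runs that check on a hypothetical one-layer, single-head model (`tiny_causal_lm`, toy sizes, no dropout), not on the `GPT` class itself. The same check can be run on `GPT` after calling `model.eval()`, so that dropout does not perturb the outputs:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, d_model, seq_len = 50, 16, 6   # toy sizes (assumed)

# Hypothetical minimal causal model: embedding -> masked self-attention -> vocab projection
embed = nn.Embedding(vocab_size, d_model)
qkv = nn.Linear(d_model, 3 * d_model, bias=False)
head = nn.Linear(d_model, vocab_size, bias=False)

def tiny_causal_lm(token_ids):
    x = embed(token_ids)                                   # (batch, seq, d_model)
    q, k, v = qkv(x).chunk(3, dim=-1)
    scores = q @ k.transpose(-2, -1) / d_model ** 0.5     # (batch, seq, seq)
    n = token_ids.shape[1]
    mask = torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(mask, float("-inf"))
    return head(F.softmax(scores, dim=-1) @ v)             # (batch, seq, vocab)

a = torch.randint(0, vocab_size, (1, seq_len))
b = a.clone()
b[0, -1] = (a[0, -1] + 1) % vocab_size                     # change only the last token

with torch.no_grad():
    logits_a, logits_b = tiny_causal_lm(a), tiny_causal_lm(b)

print(torch.allclose(logits_a[:, :-1], logits_b[:, :-1]))  # earlier positions unchanged
print(torch.allclose(logits_a[:, -1], logits_b[:, -1]))    # last position changed
```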

What Are Logits?

The model outputs logits—raw, unnormalized scores for each token in the vocabulary. Higher scores mean the model thinks that token is more likely to come next.

To convert logits to probabilities, apply softmax:

$$P(\text{token}_i) = \frac{e^{\text{logit}_i}}{\sum_j e^{\text{logit}_j}}$$
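Here is a worked example of the formula on three made-up logits. Subtracting the largest logit first leaves the result unchanged, because the common factor $e^{-m}$ cancels between numerator and denominator, but it keeps `exp` from overflowing on large scores:

```python
import math
import torch

logits = [2.0, 1.0, 0.1]                      # made-up scores for a 3-token vocabulary
m = max(logits)                               # shift by the max for numerical safety
exps = [math.exp(x - m) for x in logits]
probs = [e / sum(exps) for e in exps]
print([round(p, 3) for p in probs])           # [0.659, 0.242, 0.099]

# Same result from PyTorch
torch_probs = torch.softmax(torch.tensor(logits), dim=-1)
print(torch.allclose(torch.tensor(probs), torch_probs))  # True
```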

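Then you can choose how to pick the next token:

  • Greedy decoding: Pick the highest-probability token

  • Sampling: Randomly sample from the distribution (for variety)

  • Top-k/Top-p sampling: Restrict sampling to the most likely tokens, either the k highest-scoring ones (top-k) or the smallest set whose cumulative probability reaches p (top-p)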
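These strategies can be sketched as small functions over a 1-D logits tensor. The function names are hypothetical, and this is one straightforward way to implement each idea, not the only one. `sample` also takes a `temperature` argument, a common extra knob not discussed above; dividing the logits by a value below 1 makes the distribution peakier:

```python
import torch
import torch.nn.functional as F

def greedy(logits):
    """Pick the single highest-scoring token."""
    return int(torch.argmax(logits))

def sample(logits, temperature=1.0, generator=None):
    """Sample from the full softmax distribution."""
    probs = F.softmax(logits / temperature, dim=-1)
    return int(torch.multinomial(probs, num_samples=1, generator=generator))

def top_k_sample(logits, k, generator=None):
    """Sample only among the k highest-scoring tokens."""
    values, indices = torch.topk(logits, k)
    probs = F.softmax(values, dim=-1)
    choice = torch.multinomial(probs, num_samples=1, generator=generator)
    return int(indices[choice])

def top_p_sample(logits, p, generator=None):
    """Sample from the smallest set of tokens whose total probability reaches p."""
    probs = F.softmax(logits, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
    keep = mass_before < p            # the top token is always kept
    kept = sorted_probs[keep]
    choice = torch.multinomial(kept / kept.sum(), num_samples=1, generator=generator)
    return int(sorted_idx[keep][choice])

g = torch.Generator().manual_seed(0)
logits = torch.tensor([3.0, 2.5, 0.5, -1.0, -2.0])   # made-up scores, 5-token vocab
print(greedy(logits))                                 # 0
print(top_k_sample(logits, k=2, generator=g))         # 0 or 1
print(top_p_sample(logits, p=0.9, generator=g))       # 0 or 1 (tokens 0 and 1 hold ~94%)
```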

# Convert logits to probabilities
last_position_logits = logits[0, -1, :]  # Last position of first batch
probs = F.softmax(last_position_logits, dim=-1)

# Top 5 predictions
top_probs, top_indices = torch.topk(probs, 5)

print("Top 5 predicted next tokens (before training):")
for prob, idx in zip(top_probs, top_indices):
    print(f"  Token {idx.item():>5}: {prob.item():.4f} probability")

print("\nThese are random because the model is untrained!")
print(f"With 10,000 tokens, uniform probability = {1/10000:.5f}")
Top 5 predicted next tokens (before training):
  Token  4856: 0.0003 probability
  Token  2264: 0.0003 probability
  Token  5763: 0.0003 probability
  Token  5349: 0.0003 probability
  Token  3286: 0.0003 probability

These are random because the model is untrained!
With 10,000 tokens, uniform probability = 0.00010

A Single Training Step

Let’s see what one training step looks like. We’ll compute the loss and update the weights.

# Setup
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Training uses next-token prediction:
# Input:  [token0, token1, token2, token3, ...]
# Target: [token1, token2, token3, token4, ...]  (shifted by 1)

# We predict each token from the preceding tokens
input_tokens = tokens[:, :-1]   # All but last
target_tokens = tokens[:, 1:]   # All but first

print(f"Input:  {input_tokens.shape}")
print(f"Target: {target_tokens.shape}")
Input:  torch.Size([2, 15])
Target: torch.Size([2, 15])
# Forward pass
logits = model(input_tokens)

# Compute cross-entropy loss
# Reshape for cross_entropy: (batch × seq_len, vocab_size) vs (batch × seq_len)
loss = F.cross_entropy(
    logits.view(-1, vocab_size),
    target_tokens.reshape(-1)
)

print(f"Loss: {loss.item():.4f}")
print(f"\nExpected loss for random predictions: {math.log(vocab_size):.2f}")
print(f"(Cross-entropy of uniform distribution over {vocab_size} tokens)")
Loss: 9.1634

Expected loss for random predictions: 9.21
(Cross-entropy of uniform distribution over 10000 tokens)
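Where does the 9.21 baseline come from? If all logits are equal, softmax assigns $1/V$ to every token, so the cross-entropy is $-\ln(1/V) = \ln V$ whatever the target is. Because the weights are initialized with a small standard deviation (0.02), the untrained model's output distribution is close to uniform, which is why its loss starts near this value. A quick check with exactly equal logits:

```python
import math
import torch
import torch.nn.functional as F

vocab_size = 10000
num_targets = 30                                    # e.g. 2 sequences x 15 positions
logits = torch.zeros(num_targets, vocab_size)       # equal scores -> uniform softmax
targets = torch.randint(0, vocab_size, (num_targets,))

loss = F.cross_entropy(logits, targets)
print(f"loss = {loss.item():.4f}, ln(V) = {math.log(vocab_size):.4f}")
```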
# Backward pass
optimizer.zero_grad()
loss.backward()

# Check gradient magnitudes
total_grad_norm = 0
for p in model.parameters():
    if p.grad is not None:
        total_grad_norm += p.grad.norm().item() ** 2
total_grad_norm = total_grad_norm ** 0.5

print(f"Total gradient norm: {total_grad_norm:.4f}")
Total gradient norm: 8.1622
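The loop above computes the global L2 norm over all gradients. PyTorch's `torch.nn.utils.clip_grad_norm_` returns the same total norm (measured before any clipping), and with `max_norm=float("inf")` it leaves the gradients untouched. Here is a sketch on a small stand-in network (an assumption, not the GPT model above):

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(8, 16), nn.GELU(), nn.Linear(16, 4))  # stand-in model
loss = net(torch.randn(5, 8)).pow(2).mean()
loss.backward()

# Manual global L2 norm, as in the cell above
manual = math.sqrt(sum(p.grad.norm().item() ** 2 for p in net.parameters() if p.grad is not None))

# clip_grad_norm_ returns the total norm measured before clipping;
# max_norm=inf means the gradients are scaled by 1.0, i.e. left unchanged
total = torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=float("inf"))
print(f"manual: {manual:.6f}  clip_grad_norm_: {total.item():.6f}")
```

In real training, gradient clipping is often applied with a finite `max_norm` between `loss.backward()` and `optimizer.step()`.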
# Update weights
optimizer.step()

print("Completed one training step!")
print("\nTo train a language model:")
print("  - Repeat this millions of times")
print("  - On real text data")
print("  - With techniques like gradient accumulation")
Completed one training step!

To train a language model:
  - Repeat this millions of times
  - On real text data
  - With techniques like gradient accumulation
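"Repeat this" means running the same forward, backward, and step cycle in a loop. The minimal sketch below uses a stand-in embedding-plus-linear model (an assumption, chosen so the example runs in about a second) and a single fixed batch. The loss should fall as the model memorizes the batch, which is a common sanity check before training on real data:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, d_model = 50, 32
# Stand-in model (assumption): embedding + linear, i.e. a bigram predictor.
# Any module mapping (batch, seq) ids -> (batch, seq, vocab) logits fits this loop.
net = nn.Sequential(nn.Embedding(vocab_size, d_model), nn.Linear(d_model, vocab_size))
optimizer = torch.optim.AdamW(net.parameters(), lr=1e-2)

tokens = torch.randint(0, vocab_size, (2, 16))
inputs, targets = tokens[:, :-1], tokens[:, 1:]        # next-token shift, as above

losses = []
for step in range(100):
    logits = net(inputs)                               # (2, 15, vocab_size)
    loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```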

Model Scale Comparison

Our model is tiny. Here’s how it compares:

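| Model | Layers | $d_{model}$ | Heads | Parameters |
|---|---|---|---|---|
| Ours | 4 | 256 | 4 | ~5.8M |
| GPT-2 Small | 12 | 768 | 12 | 117M |
| GPT-2 Large | 36 | 1280 | 20 | 774M |
| GPT-3 | 96 | 12288 | 96 | 175B |
| Claude/GPT-4 | not published | not published | not published | not published (unofficial estimates ~1-2T) |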

But the core architecture is the same: the same attention, the same FFN, the same residuals. GPT-3 is, to a first approximation, our model with bigger matrices and more layers.
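A back-of-the-envelope check of that claim: with $d_{ff} = 4\,d_{model}$, each block has about $4d^2$ attention weights plus $2 \cdot d \cdot 4d = 8d^2$ FFN weights, so roughly $12 \cdot L \cdot d^2$ across $L$ layers. This rule of thumb (`approx_block_params` is a hypothetical helper) ignores biases, LayerNorms, and embeddings:

```python
def approx_block_params(num_layers, d_model):
    """Rule of thumb: about 12 * d_model^2 weights per block when d_ff = 4 * d_model."""
    return 12 * num_layers * d_model ** 2

configs = [("Ours", 4, 256), ("GPT-2 Small", 12, 768),
           ("GPT-2 Large", 36, 1280), ("GPT-3", 96, 12288)]
for name, layers, d in configs:
    print(f"{name:<12} ~{approx_block_params(layers, d):>18,}")
```

For our model it gives 3,145,728, close to the exact 3,154,944 for the four blocks. For the GPT-3 row it gives about 174B, consistent with the 175B in the table once embeddings are added.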

Key Takeaways

  1. Five stages: Embed → Blocks → LayerNorm → Project → Logits

  2. Weight tying shares parameters between embedding and output projection

  3. Logits are raw scores; apply softmax for probabilities

  4. Training predicts each token from preceding tokens (shifted targets)

  5. Architecture scales: The same fundamentals work from our ~5.8M-parameter model up to GPT-3's 175B and beyond

Next: Training at Scale

We have a complete model. But training it well requires techniques beyond basic gradient descent:

  • Gradient accumulation: Simulate larger batches without more memory

  • Validation: Detect overfitting before it’s too late

  • Learning rate schedules: Warm up, then decay

In the next notebook, we’ll cover practical training strategies.
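As a preview of the first technique, here is a minimal sketch of gradient accumulation on a stand-in linear model. It splits one batch into equal-sized micro-batches, divides each micro-batch's mean loss by the number of micro-batches, and lets `backward()` sum the gradients in `.grad`. The result reproduces the gradient of the single large batch:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
net = nn.Linear(8, 3)                     # stand-in model (assumption)
x = torch.randn(12, 8)
y = torch.randint(0, 3, (12,))

# Reference: one full batch of 12 examples
net.zero_grad()
F.cross_entropy(net(x), y).backward()
full_grad = net.weight.grad.clone()

# Accumulation: 3 micro-batches of 4 examples, each loss scaled by 1/3
net.zero_grad()
accum_steps = 3
for xb, yb in zip(x.chunk(accum_steps), y.chunk(accum_steps)):
    (F.cross_entropy(net(xb), yb) / accum_steps).backward()  # adds into .grad
accum_grad = net.weight.grad.clone()

print(torch.allclose(full_grad, accum_grad, atol=1e-6))  # True
```

Only one micro-batch's activations need to be in memory at a time, which is where the memory saving comes from.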