
The Transformer Block

The Building Block of Transformers

We now have the two core components:

  • Multi-Head Attention: Communication between positions

  • Feed-Forward Network: Computation at each position

But if we just stack these directly, training fails. Deep networks suffer from two problems:

  1. Vanishing/exploding gradients: As we add layers, gradients get multiplied again at every layer during backpropagation. They can shrink to zero or grow to infinity (a numeric sketch follows this list).

  2. Internal covariate shift: The distribution of layer inputs keeps changing during training, making it hard to learn. When you update weights in layer 1, it changes the distribution of inputs to layer 2. Layer 2 has to constantly adapt to these shifting distributions instead of making stable progress. This is like trying to learn a new skill while the rules keep changing—you waste effort readjusting instead of improving. Layer normalization addresses this by standardizing each layer’s inputs to have consistent statistics (mean 0, variance 1), giving each layer a stable foundation to build on.
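
To make the first problem concrete, here is a minimal numeric sketch (not tied to any particular network): if every layer scales the backward signal by a constant factor, a few dozen layers are enough to collapse it to essentially zero or inflate it by a factor of billions.

# Minimal sketch of vanishing/exploding gradients:
# pretend each layer multiplies the backward signal by a fixed factor.
signal = 1.0
for _ in range(32):
    signal *= 0.5            # each "layer" shrinks the gradient by half
print(f"32 layers at 0.5x per layer: {signal:.2e}")   # ~2.3e-10 (vanished)

signal = 1.0
for _ in range(32):
    signal *= 2.0            # each "layer" doubles the gradient
print(f"32 layers at 2.0x per layer: {signal:.2e}")   # ~4.3e+09 (exploded)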

The solution: residual connections and layer normalization. These aren’t optional additions—they’re essential for training deep transformers.

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import matplotlib.pyplot as plt
import numpy as np

Residual Connections: The Skip Path

A residual connection adds the input directly to the output:

$$\text{output} = \text{sublayer}(x) + x$$

Instead of learning the transformation directly, the sublayer learns the residual—what to add to the input.

Why does this help?

Consider the gradient flow. With a residual:

$$\frac{\partial\,\text{output}}{\partial x} = \frac{\partial\,\text{sublayer}(x)}{\partial x} + 1$$

That “+1” is crucial. Even if the sublayer’s gradient vanishes, there’s still a direct path for gradients to flow backward. The network can always fall back to the identity function.
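
As a quick sketch of how this compounds with depth: writing $x_i = x_{i-1} + f_i(x_{i-1})$ for a stack of $N$ residual sublayers, the chain rule gives

$$\frac{\partial x_N}{\partial x_0} = \prod_{i=1}^{N} \left( I + \frac{\partial f_i}{\partial x_{i-1}} \right) = I + \sum_{i=1}^{N} \frac{\partial f_i}{\partial x_{i-1}} + \cdots$$

so an identity path from the output back to the input survives no matter how deep the stack is.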

Visually:

    x ─────────────────────────┐
    │                          │ (skip connection)
    ↓                          │
[sublayer]                     │
    │                          │
    ↓                          ↓
    └──────────────[ + ]───────→ output
# Demonstrating gradient flow with residuals
class WithResidual(nn.Module):
    def __init__(self, sublayer):
        super().__init__()
        self.sublayer = sublayer
    
    def forward(self, x):
        return self.sublayer(x) + x  # The key: add input to output

# Create a layer that nearly kills gradients
layer = nn.Linear(64, 64)
nn.init.normal_(layer.weight, std=0.001)  # Very small weights = near-zero gradients

# Without residual
x = torch.randn(1, 64, requires_grad=True)
out_no_res = layer(x)
loss = out_no_res.sum()
loss.backward()
grad_norm_no_res = x.grad.norm().item()

# With residual
x2 = torch.randn(1, 64, requires_grad=True)
wrapped = WithResidual(layer)
out_with_res = wrapped(x2)
loss2 = out_with_res.sum()
loss2.backward()
grad_norm_with_res = x2.grad.norm().item()

print("Gradient magnitude comparison:")
print(f"  Without residual: {grad_norm_no_res:.6f}")
print(f"  With residual:    {grad_norm_with_res:.6f}")
print(f"\nResidual connection gives {grad_norm_with_res/grad_norm_no_res:.0f}× stronger gradient!")
Gradient magnitude comparison:
  Without residual: 0.067134
  With residual:    8.002588

Residual connection gives 119× stronger gradient!

Layer Normalization: Stabilizing Activations

Layer normalization normalizes activations across the feature dimension (not the batch dimension like batch norm).

For each position’s vector $x \in \mathbb{R}^{d_{\text{model}}}$:

$$\text{LayerNorm}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

where:

  • $\mu = \frac{1}{d} \sum_i x_i$ (mean across features)

  • $\sigma^2 = \frac{1}{d} \sum_i (x_i - \mu)^2$ (variance across features)

  • $\gamma, \beta$ are learned scale and shift parameters

  • $\epsilon$ is a small constant for numerical stability

What does it do?

It keeps activations in a consistent range:

  • Center around 0 (by subtracting mean)

  • Scale to unit variance (by dividing by std)

  • Then allow learned rescaling ($\gamma$, $\beta$)

This prevents activations from drifting to extreme values during training.
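
To see how this differs from batch norm concretely, here is a small sketch (shapes chosen for the example, not part of the implementation below): layer norm computes one mean per position, whereas batch-norm-style statistics are one per feature, pooled over the batch and sequence.

# Layer norm vs batch norm: which axes get reduced?
x = torch.randn(4, 8, 16)            # (batch, seq_len, d_model)

ln_mean = x.mean(dim=-1)             # layer norm: one mean per (example, position) -> shape (4, 8)
bn_mean = x.mean(dim=(0, 1))         # batch-norm style: one mean per feature      -> shape (16,)

print(ln_mean.shape, bn_mean.shape)  # torch.Size([4, 8]) torch.Size([16])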

class LayerNorm(nn.Module):
    """
    Layer normalization.
    
    Normalizes across the feature dimension (last dimension).
    """
    
    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        
        # Learned parameters
        self.gamma = nn.Parameter(torch.ones(d_model))   # Scale
        self.beta = nn.Parameter(torch.zeros(d_model))   # Shift
        self.eps = eps
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Compute mean and variance across last dimension
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        
        # Normalize
        x_norm = (x - mean) / torch.sqrt(var + self.eps)
        
        # Scale and shift
        return self.gamma * x_norm + self.beta
# Demonstrate layer norm
d_model = 256
ln = LayerNorm(d_model)

# Create input with varying scales
x = torch.randn(2, 5, d_model) * 10 + 5  # Mean ~5, std ~10

x_norm = ln(x)

print("Before LayerNorm:")
print(f"  Mean: {x.mean(dim=-1)[0].detach().numpy().round(2)}")
print(f"  Std:  {x.std(dim=-1)[0].detach().numpy().round(2)}")
print()
print("After LayerNorm:")
print(f"  Mean: {x_norm.mean(dim=-1)[0].detach().numpy().round(4)}")
print(f"  Std:  {x_norm.std(dim=-1)[0].detach().numpy().round(2)}")
print()
print("LayerNorm centers each position to mean≈0, std≈1")
Before LayerNorm:
  Mean: [5.11 5.48 4.57 4.77 4.15]
  Std:  [10.75  9.68  9.43 10.18  9.9 ]

After LayerNorm:
  Mean: [ 0. -0. -0.  0.  0.]
  Std:  [1. 1. 1. 1. 1.]

LayerNorm centers each position to mean≈0, std≈1
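
As a sanity check (a quick sketch, not one of the original cells), our implementation should agree with PyTorch’s built-in nn.LayerNorm once the eps values match:

# Compare our LayerNorm against torch.nn.LayerNorm (same eps for an exact match)
builtin = nn.LayerNorm(d_model, eps=1e-6)
with torch.no_grad():
    builtin.weight.copy_(ln.gamma)   # copy our scale parameter into the built-in module
    builtin.bias.copy_(ln.beta)      # copy our shift parameter

print(torch.allclose(ln(x), builtin(x), atol=1e-5))  # expected: True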

Pre-Norm vs Post-Norm

There are two ways to combine layer norm with residuals:

Post-Norm (original transformer):

$$x = \text{LayerNorm}(x + \text{sublayer}(x))$$

Pre-Norm (GPT-2 and later):

$$x = x + \text{sublayer}(\text{LayerNorm}(x))$$

Post-Norm:                    Pre-Norm:
x ────┐                       x ───────────────┐
│     │                       │                │
↓     │                       ↓                │
[sublayer]                [LayerNorm]          │
│     │                       │                │
↓     ↓                       ↓                │
[ + ]←─                   [sublayer]           │
│                             │                │
↓                             ↓                ↓
[LayerNorm]                   └────[ + ]←──────
│                                  │
↓                                  ↓
output                          output

Pre-norm is more stable because the residual path is completely clean—no normalization interfering with gradient flow. This matters especially for very deep models.
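
The two orderings are easy to express as small wrapper modules. The sketch below is illustrative only (the class names are made up for this example); the TransformerBlock later in this notebook hard-codes the pre-norm ordering instead of using wrappers.

# Illustrative wrappers for the two orderings; `sublayer` could be attention or an FFN
class PostNormResidual(nn.Module):
    """Original transformer: normalize after adding the residual."""
    def __init__(self, d_model: int, sublayer: nn.Module):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(d_model)
    
    def forward(self, x):
        return self.norm(x + self.sublayer(x))


class PreNormResidual(nn.Module):
    """GPT-2 style: normalize before the sublayer; the skip path stays untouched."""
    def __init__(self, d_model: int, sublayer: nn.Module):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(d_model)
    
    def forward(self, x):
        return x + self.sublayer(self.norm(x))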

The Complete Transformer Block

Now we can assemble the full block:

Input x
    │
    ├───────────────────────────────┐
    ↓                               │ (residual 1)
[LayerNorm]                         │
    ↓                               │
[Multi-Head Attention]              │
    ↓                               │
[Dropout]                           │
    ↓                               ↓
    └───────────────[ + ]←──────────┘
                     │
    ┌────────────────┴──────────────┐
    │                               │ (residual 2)
    ↓                               │
[LayerNorm]                         │
    ↓                               │
[Feed-Forward Network]              │
    ↓                               │
[Dropout]                           │
    ↓                               ↓
    └───────────────[ + ]←──────────┘
                     │
                     ↓
                  Output
class MultiHeadAttention(nn.Module):
    """Multi-head attention (from previous notebook)"""
    
    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
        batch_size, seq_len, _ = x.shape
        
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask, float('-inf'))
        
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        
        context = torch.matmul(attn, V)
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        
        return self.W_o(context)


class FeedForward(nn.Module):
    """Feed-forward network (from previous notebook)"""
    
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.linear2(F.gelu(self.linear1(x))))
class TransformerBlock(nn.Module):
    """
    A single transformer block.
    
    Architecture (Pre-Norm):
        x → LayerNorm → Attention → Dropout → (+x) → 
          → LayerNorm → FFN → Dropout → (+) → output
    """
    
    def __init__(
        self, 
        d_model: int, 
        num_heads: int, 
        d_ff: int, 
        dropout: float = 0.1
    ):
        super().__init__()
        
        # Attention sublayer
        self.attn_norm = nn.LayerNorm(d_model)
        self.attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.attn_dropout = nn.Dropout(dropout)
        
        # FFN sublayer
        self.ffn_norm = nn.LayerNorm(d_model)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.ffn_dropout = nn.Dropout(dropout)
    
    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """
        Args:
            x: (batch, seq_len, d_model)
            mask: optional boolean causal mask; True marks positions to block
        
        Returns:
            (batch, seq_len, d_model)
        """
        # Attention block with residual
        attn_out = self.attention(self.attn_norm(x), mask)
        x = x + self.attn_dropout(attn_out)
        
        # FFN block with residual
        ffn_out = self.ffn(self.ffn_norm(x))
        x = x + self.ffn_dropout(ffn_out)
        
        return x
# Test the transformer block
d_model = 256
num_heads = 4
d_ff = 1024

block = TransformerBlock(d_model, num_heads, d_ff, dropout=0.0)

# Input
x = torch.randn(2, 8, d_model)

# Causal mask
seq_len = 8
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()

# Forward
output = block(x, mask)

print(f"Input shape:  {x.shape}")
print(f"Output shape: {output.shape}")
print(f"\nOutput has same shape as input ✓")
print(f"This allows us to stack blocks!")
Input shape:  torch.Size([2, 8, 256])
Output shape: torch.Size([2, 8, 256])

Output has same shape as input ✓
This allows us to stack blocks!

Parameter Count

Let’s break down the parameters in one block:

# Count parameters by component
def count_params(module):
    return sum(p.numel() for p in module.parameters())

print("Transformer Block Parameters")
print("=" * 50)
print(f"Attention sublayer:")
print(f"  LayerNorm:   {count_params(block.attn_norm):>12,}")
print(f"  Attention:   {count_params(block.attention):>12,}")
print()
print(f"FFN sublayer:")
print(f"  LayerNorm:   {count_params(block.ffn_norm):>12,}")
print(f"  FFN:         {count_params(block.ffn):>12,}")
print("=" * 50)
print(f"Total:         {count_params(block):>12,}")
print()

# Show breakdown
attn_params = count_params(block.attention)
ffn_params = count_params(block.ffn)
total_params = count_params(block)

print(f"Proportion:")
print(f"  Attention: {attn_params/total_params*100:.1f}%")
print(f"  FFN:       {ffn_params/total_params*100:.1f}%")
Transformer Block Parameters
==================================================
Attention sublayer:
  LayerNorm:            512
  Attention:        262,144

FFN sublayer:
  LayerNorm:            512
  FFN:              525,568
==================================================
Total:              788,736

Proportion:
  Attention: 33.2%
  FFN:       66.6%
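
These numbers can be reproduced by hand from the layer shapes (a quick check using the same d_model = 256 and d_ff = 1024 as above): the attention sublayer is four bias-free d_model × d_model projections, the FFN is two biased linear layers, and each LayerNorm contributes a scale and a shift per feature.

# Closed-form parameter counts for d_model=256, d_ff=1024
d_model, d_ff = 256, 1024

attn_expected = 4 * d_model * d_model                                # W_q, W_k, W_v, W_o (no biases)
ffn_expected = (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)  # linear1 + linear2, with biases
norm_expected = 2 * (2 * d_model)                                    # two LayerNorms: gamma + beta each

print(attn_expected, ffn_expected, norm_expected)                    # 262144 525568 1024
print(attn_expected + ffn_expected + norm_expected)                  # 788736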

Stacking Blocks

The power of transformers comes from stacking multiple blocks. Each block refines the representations produced by the one before it. Roughly speaking:

  • Layer 1: Basic patterns (adjacent words, simple grammar)

  • Layer 2-3: Higher-level patterns (phrases, basic semantics)

  • Layer 4+: Abstract reasoning (long-range dependencies, complex relations)

More layers give the model more rounds of communication (attention) and computation (FFN), and therefore more refined representations.

class TransformerStack(nn.Module):
    """
    Stack of transformer blocks.
    """
    
    def __init__(
        self, 
        num_layers: int,
        d_model: int, 
        num_heads: int, 
        d_ff: int, 
        dropout: float = 0.1
    ):
        super().__init__()
        
        self.layers = nn.ModuleList([
            TransformerBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
    
    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x, mask)
        return x

# Create a stack
stack = TransformerStack(
    num_layers=4,
    d_model=256,
    num_heads=4,
    d_ff=1024,
    dropout=0.0
)

# Test
x = torch.randn(2, 8, 256)
output = stack(x, mask)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"\nTotal parameters in 4-layer stack: {count_params(stack):,}")
Input shape: torch.Size([2, 8, 256])
Output shape: torch.Size([2, 8, 256])

Total parameters in 4-layer stack: 3,154,944
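
This is exactly 4 × 788,736: the blocks share no weights, so the stack’s parameter count is simply the per-block count times the number of layers.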

Why Residuals and LayerNorm Matter

Let’s see what happens without them:

# Compare gradient flow with and without residuals
class BlockWithoutResidual(nn.Module):
    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, num_heads)
        self.ffn = FeedForward(d_model, d_ff)
    
    def forward(self, x, mask=None):
        x = self.attention(x, mask)
        x = self.ffn(x)
        return x

# Stack 8 blocks without residuals
blocks_no_res = nn.Sequential(*[
    BlockWithoutResidual(256, 4, 1024) for _ in range(8)
])

# Stack 8 blocks with residuals (our proper implementation)
blocks_with_res = TransformerStack(8, 256, 4, 1024, dropout=0.0)

# Test gradient magnitude through the network
x = torch.randn(1, 8, 256, requires_grad=True)

# Without residuals
out = blocks_no_res(x)
loss = out.sum()
loss.backward()
grad_no_res = x.grad.norm().item()

# With residuals
x2 = torch.randn(1, 8, 256, requires_grad=True)
out2 = blocks_with_res(x2, None)
loss2 = out2.sum()
loss2.backward()
grad_with_res = x2.grad.norm().item()

print("Gradient magnitude after 8 layers:")
print(f"  Without residuals: {grad_no_res:.6f}")
print(f"  With residuals:    {grad_with_res:.6f}")
print(f"\nResiduals preserve gradient signal ~{grad_with_res/grad_no_res:.0f}× better!")
Gradient magnitude after 8 layers:
  Without residuals: 0.000000
  With residuals:    73.678993

Residuals preserve gradient signal ~11590260933× better!
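
(The gradient without residuals is not literally zero, just small enough to print as 0.000000 at six decimal places, which is why the ratio comes out so enormous.)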

Key Takeaways

  1. Residual connections add input directly to output, preserving gradient flow

  2. Layer normalization keeps activations in a stable range

  3. Pre-norm (normalize before sublayer) is more stable than post-norm

  4. Each block has the same structure: Attention + FFN, each with residual and norm

  5. Stacking blocks builds increasingly abstract representations

Next: Complete Model

We have the transformer block—the fundamental building unit. Now we need to wrap it with:

  • Input embeddings (token + position)

  • Output projection (to vocabulary logits)

  • Final layer norm

In the next notebook, we’ll assemble the complete model and count every parameter.