The Building Block of Transformers¶
We now have the two core components:
Multi-Head Attention: Communication between positions
Feed-Forward Network: Computation at each position
But if we just stack these directly, training fails. Deep networks suffer from two problems:
Vanishing/exploding gradients: As we add layers, the gradient is multiplied by another factor at each layer. It can shrink to zero or grow to infinity (see the quick arithmetic after this list).
Internal covariate shift: The distribution of layer inputs keeps changing during training, making it hard to learn. When you update weights in layer 1, it changes the distribution of inputs to layer 2. Layer 2 has to constantly adapt to these shifting distributions instead of making stable progress. This is like trying to learn a new skill while the rules keep changing—you waste effort readjusting instead of improving. Layer normalization addresses this by standardizing each layer’s inputs to have consistent statistics (mean 0, variance 1), giving each layer a stable foundation to build on.
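For a rough sense of scale (illustrative numbers, not measurements from this notebook): if each of 64 layers scales the backward signal by about 0.9, the end-to-end factor is $0.9^{64} \approx 0.0012$; if each scales it by about 1.1, the factor is $1.1^{64} \approx 446$. Per-layer factors even slightly away from 1 compound into vanishing or exploding gradients.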
The solution: residual connections and layer normalization. These aren’t optional additions—they’re essential for training deep transformers.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import matplotlib.pyplot as plt
import numpy as np
Residual Connections: The Skip Path¶
A residual connection adds the input directly to the output:

$$\text{output} = x + \text{Sublayer}(x)$$

Instead of learning the transformation directly, the sublayer learns the residual: what to add to the input.

Why does this help?

Consider the gradient flow. With a residual:

$$\frac{\partial\,\text{output}}{\partial x} = 1 + \frac{\partial\,\text{Sublayer}(x)}{\partial x}$$

That “+1” is crucial. Even if the sublayer’s gradient vanishes, there’s still a direct path for gradients to flow backward. The network can always fall back to the identity function.
Visually:
x ─────────────────────────┐
│                          │  (skip connection)
↓                          │
[sublayer]                 │
│                          │
↓                          ↓
└──────────────[ + ]───────→ output
# Demonstrating gradient flow with residuals
class WithResidual(nn.Module):
def __init__(self, sublayer):
super().__init__()
self.sublayer = sublayer
def forward(self, x):
return self.sublayer(x) + x # The key: add input to output
# Create a layer that nearly kills gradients
layer = nn.Linear(64, 64)
nn.init.normal_(layer.weight, std=0.001) # Very small weights = near-zero gradients
# Without residual
x = torch.randn(1, 64, requires_grad=True)
out_no_res = layer(x)
loss = out_no_res.sum()
loss.backward()
grad_norm_no_res = x.grad.norm().item()
# With residual
x2 = torch.randn(1, 64, requires_grad=True)
wrapped = WithResidual(layer)
out_with_res = wrapped(x2)
loss2 = out_with_res.sum()
loss2.backward()
grad_norm_with_res = x2.grad.norm().item()
print("Gradient magnitude comparison:")
print(f" Without residual: {grad_norm_no_res:.6f}")
print(f" With residual: {grad_norm_with_res:.6f}")
print(f"\nResidual connection gives {grad_norm_with_res/grad_norm_no_res:.0f}× stronger gradient!")Gradient magnitude comparison:
Without residual: 0.067134
With residual: 8.002588
Residual connection gives 119× stronger gradient!
Layer Normalization: Stabilizing Activations¶
Layer normalization normalizes activations across the feature dimension (not the batch dimension like batch norm).
For each position’s vector $x \in \mathbb{R}^{d_\text{model}}$:

$$\text{LayerNorm}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

where:
$\mu = \frac{1}{d_\text{model}} \sum_{i=1}^{d_\text{model}} x_i$ (mean across features)
$\sigma^2 = \frac{1}{d_\text{model}} \sum_{i=1}^{d_\text{model}} (x_i - \mu)^2$ (variance across features)
$\gamma, \beta$ are learned scale and shift parameters
$\epsilon$ is a small constant for numerical stability
What does it do?
It keeps activations in a consistent range:
Center around 0 (by subtracting mean)
Scale to unit variance (by dividing by std)
Then allow learned rescaling ($\gamma$, $\beta$)
This prevents activations from drifting to extreme values during training.
class LayerNorm(nn.Module):
"""
Layer normalization.
Normalizes across the feature dimension (last dimension).
"""
def __init__(self, d_model: int, eps: float = 1e-6):
super().__init__()
# Learned parameters
self.gamma = nn.Parameter(torch.ones(d_model)) # Scale
self.beta = nn.Parameter(torch.zeros(d_model)) # Shift
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Compute mean and variance across last dimension
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)
# Normalize
x_norm = (x - mean) / torch.sqrt(var + self.eps)
# Scale and shift
        return self.gamma * x_norm + self.beta
# Demonstrate layer norm
d_model = 256
ln = LayerNorm(d_model)
# Create input with varying scales
x = torch.randn(2, 5, d_model) * 10 + 5 # Mean ~5, std ~10
x_norm = ln(x)
print("Before LayerNorm:")
print(f" Mean: {x.mean(dim=-1)[0].detach().numpy().round(2)}")
print(f" Std: {x.std(dim=-1)[0].detach().numpy().round(2)}")
print()
print("After LayerNorm:")
print(f" Mean: {x_norm.mean(dim=-1)[0].detach().numpy().round(4)}")
print(f" Std: {x_norm.std(dim=-1)[0].detach().numpy().round(2)}")
print()
print("LayerNorm centers each position to mean≈0, std≈1")Before LayerNorm:
Mean: [5.11 5.48 4.57 4.77 4.15]
Std: [10.75 9.68 9.43 10.18 9.9 ]
After LayerNorm:
Mean: [ 0. -0. -0. 0. 0.]
Std: [1. 1. 1. 1. 1.]
LayerNorm centers each position to mean≈0, std≈1
Pre-Norm vs Post-Norm¶
There are two ways to combine layer norm with residuals:
Post-Norm (original transformer):

$$\text{output} = \text{LayerNorm}(x + \text{Sublayer}(x))$$

Pre-Norm (GPT-2 and later):

$$\text{output} = x + \text{Sublayer}(\text{LayerNorm}(x))$$
Post-Norm:                  Pre-Norm:
x ────┐                     x ───────────────┐
│     │                     │                │
↓     │                     ↓                │
[sublayer]                  [LayerNorm]      │
│     │                     │                │
↓     ↓                     ↓                │
[ + ]←─                     [sublayer]       │
│                           │                │
↓                           ↓                ↓
[LayerNorm]                 └────[ + ]←──────┘
│                                  │
↓                                  ↓
output                           output

Pre-norm is more stable because the residual path is completely clean: no normalization interferes with gradient flow. This matters especially for very deep models.
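To make the two orderings concrete, here is a minimal sketch (the class names PostNormResidual and PreNormResidual are illustrative, not part of this notebook's API) that wraps an arbitrary sublayer both ways:

# Minimal sketch of the two orderings; `sublayer` stands in for attention or the FFN
class PostNormResidual(nn.Module):
    """Original transformer ordering: normalize after adding the residual."""
    def __init__(self, d_model, sublayer):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(d_model)
    def forward(self, x):
        return self.norm(x + self.sublayer(x))

class PreNormResidual(nn.Module):
    """GPT-2-style ordering: normalize the sublayer input, keep the skip path untouched."""
    def __init__(self, d_model, sublayer):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(d_model)
    def forward(self, x):
        return x + self.sublayer(self.norm(x))

# Both preserve the input shape, so either can be stacked
x = torch.randn(2, 8, 64)
print(PostNormResidual(64, nn.Linear(64, 64))(x).shape)  # torch.Size([2, 8, 64])
print(PreNormResidual(64, nn.Linear(64, 64))(x).shape)   # torch.Size([2, 8, 64])

The only structural difference is where the LayerNorm sits; in the pre-norm version the term added to x never passes through a normalization on the skip path.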
The Complete Transformer Block¶
Now we can assemble the full block:
Input x
   │
   ├───────────────────────────────┐
   ↓                               │  (residual 1)
[LayerNorm]                        │
   ↓                               │
[Multi-Head Attention]             │
   ↓                               │
[Dropout]                          │
   ↓                               ↓
   └───────────────[ + ]←──────────┘
                     │
   ┌────────────────┴──────────────┐
   │                               │  (residual 2)
   ↓                               │
[LayerNorm]                        │
   ↓                               │
[Feed-Forward Network]             │
   ↓                               │
[Dropout]                          │
   ↓                               ↓
   └───────────────[ + ]←──────────┘
                     │
                     ↓
                  Output
class MultiHeadAttention(nn.Module):
"""Multi-head attention (from previous notebook)"""
def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
super().__init__()
assert d_model % num_heads == 0
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
self.W_q = nn.Linear(d_model, d_model, bias=False)
self.W_k = nn.Linear(d_model, d_model, bias=False)
self.W_v = nn.Linear(d_model, d_model, bias=False)
self.W_o = nn.Linear(d_model, d_model, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
batch_size, seq_len, _ = x.shape
Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
if mask is not None:
scores = scores.masked_fill(mask, float('-inf'))
attn = F.softmax(scores, dim=-1)
attn = self.dropout(attn)
context = torch.matmul(attn, V)
context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
return self.W_o(context)
class FeedForward(nn.Module):
"""Feed-forward network (from previous notebook)"""
def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
super().__init__()
self.linear1 = nn.Linear(d_model, d_ff)
self.linear2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.linear2(F.gelu(self.linear1(x))))
class TransformerBlock(nn.Module):
"""
A single transformer block.
Architecture (Pre-Norm):
x → LayerNorm → Attention → Dropout → (+x) →
→ LayerNorm → FFN → Dropout → (+) → output
"""
def __init__(
self,
d_model: int,
num_heads: int,
d_ff: int,
dropout: float = 0.1
):
super().__init__()
# Attention sublayer
self.attn_norm = nn.LayerNorm(d_model)
self.attention = MultiHeadAttention(d_model, num_heads, dropout)
self.attn_dropout = nn.Dropout(dropout)
# FFN sublayer
self.ffn_norm = nn.LayerNorm(d_model)
self.ffn = FeedForward(d_model, d_ff, dropout)
self.ffn_dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
"""
Args:
x: (batch, seq_len, d_model)
mask: causal mask
Returns:
(batch, seq_len, d_model)
"""
# Attention block with residual
attn_out = self.attention(self.attn_norm(x), mask)
x = x + self.attn_dropout(attn_out)
# FFN block with residual
ffn_out = self.ffn(self.ffn_norm(x))
x = x + self.ffn_dropout(ffn_out)
        return x
# Test the transformer block
d_model = 256
num_heads = 4
d_ff = 1024
block = TransformerBlock(d_model, num_heads, d_ff, dropout=0.0)
# Input
x = torch.randn(2, 8, d_model)
# Causal mask
seq_len = 8
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
# Forward
output = block(x, mask)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"\nOutput has same shape as input ✓")
print(f"This allows us to stack blocks!")Input shape: torch.Size([2, 8, 256])
Output shape: torch.Size([2, 8, 256])
Output has same shape as input ✓
This allows us to stack blocks!
Parameter Count¶
Let’s break down the parameters in one block:
# Count parameters by component
def count_params(module):
return sum(p.numel() for p in module.parameters())
print("Transformer Block Parameters")
print("=" * 50)
print(f"Attention sublayer:")
print(f" LayerNorm: {count_params(block.attn_norm):>12,}")
print(f" Attention: {count_params(block.attention):>12,}")
print()
print(f"FFN sublayer:")
print(f" LayerNorm: {count_params(block.ffn_norm):>12,}")
print(f" FFN: {count_params(block.ffn):>12,}")
print("=" * 50)
print(f"Total: {count_params(block):>12,}")
print()
# Show breakdown
attn_params = count_params(block.attention)
ffn_params = count_params(block.ffn)
total_params = count_params(block)
print(f"Proportion:")
print(f" Attention: {attn_params/total_params*100:.1f}%")
print(f" FFN: {ffn_params/total_params*100:.1f}%")Transformer Block Parameters
==================================================
Attention sublayer:
LayerNorm: 512
Attention: 262,144
FFN sublayer:
LayerNorm: 512
FFN: 525,568
==================================================
Total: 788,736
Proportion:
Attention: 33.2%
FFN: 66.6%
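These counts follow directly from the layer shapes (with d_model = 256 and d_ff = 1024):
Attention: four bias-free d_model × d_model projections (W_q, W_k, W_v, W_o) = 4 × 256 × 256 = 262,144
FFN: (256 × 1024 + 1024) + (1024 × 256 + 256) = 263,168 + 262,400 = 525,568
Each LayerNorm: γ and β of size d_model = 2 × 256 = 512
Total: 262,144 + 525,568 + 2 × 512 = 788,736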
Stacking Blocks¶
The power of transformers comes from stacking multiple blocks. Each block refines the representations:
Layer 1: Basic patterns (adjacent words, simple grammar)
Layers 2-3: Higher-level patterns (phrases, basic semantics)
Layer 4+: Abstract reasoning (long-range dependencies, complex relations)
More layers = more refinement = better understanding.
class TransformerStack(nn.Module):
"""
Stack of transformer blocks.
"""
def __init__(
self,
num_layers: int,
d_model: int,
num_heads: int,
d_ff: int,
dropout: float = 0.1
):
super().__init__()
self.layers = nn.ModuleList([
TransformerBlock(d_model, num_heads, d_ff, dropout)
for _ in range(num_layers)
])
def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
for layer in self.layers:
x = layer(x, mask)
return x
# Create a stack
stack = TransformerStack(
num_layers=4,
d_model=256,
num_heads=4,
d_ff=1024,
dropout=0.0
)
# Test
x = torch.randn(2, 8, 256)
output = stack(x, mask)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"\nTotal parameters in 4-layer stack: {count_params(stack):,}")Input shape: torch.Size([2, 8, 256])
Output shape: torch.Size([2, 8, 256])
Total parameters in 4-layer stack: 3,154,944
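That is exactly 4 × 788,736 = 3,154,944: each block keeps its own weights (nothing is shared across layers), so the stack's parameter count scales linearly with depth.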
Why Residuals and LayerNorm Matter¶
Let’s see what happens without them:
# Compare gradient flow with and without residuals
class BlockWithoutResidual(nn.Module):
def __init__(self, d_model, num_heads, d_ff):
super().__init__()
self.attention = MultiHeadAttention(d_model, num_heads)
self.ffn = FeedForward(d_model, d_ff)
def forward(self, x, mask=None):
x = self.attention(x, mask)
x = self.ffn(x)
return x
# Stack 8 blocks without residuals
blocks_no_res = nn.Sequential(*[
BlockWithoutResidual(256, 4, 1024) for _ in range(8)
])
# Stack 8 blocks with residuals (our proper implementation)
blocks_with_res = TransformerStack(8, 256, 4, 1024, dropout=0.0)
# Test gradient magnitude through the network
x = torch.randn(1, 8, 256, requires_grad=True)
# Without residuals
out = blocks_no_res(x)
loss = out.sum()
loss.backward()
grad_no_res = x.grad.norm().item()
# With residuals
x2 = torch.randn(1, 8, 256, requires_grad=True)
out2 = blocks_with_res(x2, None)
loss2 = out2.sum()
loss2.backward()
grad_with_res = x2.grad.norm().item()
print("Gradient magnitude after 8 layers:")
print(f" Without residuals: {grad_no_res:.6f}")
print(f" With residuals: {grad_with_res:.6f}")
print(f"\nResiduals preserve gradient signal ~{grad_with_res/grad_no_res:.0f}× better!")Gradient magnitude after 8 layers:
Without residuals: 0.000000
With residuals: 73.678993
Residuals preserve gradient signal ~11590260933× better!
Key Takeaways¶
Residual connections add input directly to output, preserving gradient flow
Layer normalization keeps activations in a stable range
Pre-norm (normalize before sublayer) is more stable than post-norm
Each block has the same structure: Attention + FFN, each with residual and norm
Stacking blocks builds increasingly abstract representations
Next: Complete Model¶
We have the transformer block—the fundamental building unit. Now we need to wrap it with:
Input embeddings (token + position)
Output projection (to vocabulary logits)
Final layer norm
In the next notebook, we’ll assemble the complete model and count every parameter.
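As a preview, here is a minimal sketch of how those pieces could wrap the stack we just built. The names (TinyLM, lm_head, max_len, and so on) are illustrative placeholders, not the API the next notebook will define:

# Illustrative wrapper (assumed names): embeddings -> transformer stack -> final norm -> logits
class TinyLM(nn.Module):
    def __init__(self, vocab_size, max_len, num_layers, d_model, num_heads, d_ff):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)   # token embeddings
        self.pos_emb = nn.Embedding(max_len, d_model)        # learned position embeddings
        self.blocks = TransformerStack(num_layers, d_model, num_heads, d_ff, dropout=0.0)
        self.final_norm = nn.LayerNorm(d_model)              # final layer norm
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)  # projection to vocabulary logits
    def forward(self, token_ids):
        batch_size, seq_len = token_ids.shape
        positions = torch.arange(seq_len, device=token_ids.device)
        x = self.token_emb(token_ids) + self.pos_emb(positions)
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=token_ids.device), diagonal=1
        ).bool()
        x = self.blocks(x, causal_mask)
        return self.lm_head(self.final_norm(x))  # (batch, seq_len, vocab_size)

model = TinyLM(vocab_size=1000, max_len=64, num_layers=2, d_model=256, num_heads=4, d_ff=1024)
tokens = torch.randint(0, 1000, (2, 8))
print(model(tokens).shape)  # torch.Size([2, 8, 1000])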