Position-Wise Feed-Forward Networks
What is a feed-forward network? After attention gathers information from across the sequence, we need to actually process that information. The feed-forward network (FFN) is a simple two-layer neural network—also called a Multi-Layer Perceptron (MLP)—that transforms each token’s representation independently.
Why do we need it? Think of the attention layer as “communication” between tokens—gathering relevant context. The FFN is the “computation” step—processing that gathered information to extract useful features and patterns. Without the FFN, the model would only shuffle information around without transforming it.
The Architecture
- Expand: Project from d_model (e.g., 512) to d_ff (typically 4× larger, e.g., 2048). This expansion gives the model more “capacity” to learn complex patterns.
- Activate: Apply GELU activation, a smooth nonlinear function that lets the model learn non-linear relationships. Without this nonlinearity, stacking layers would be pointless, since multiple linear transformations collapse into one (see the sketch just after this list).
- Project back: Compress from d_ff back down to d_model so the output shape matches the input, allowing us to stack more layers.
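To see why the activation is essential, here is a minimal sketch (not part of the tutorial’s own code) showing that two stacked linear layers with no nonlinearity in between are equivalent to a single linear layer:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Two linear layers with no activation in between (d_model=512, d_ff=2048).
lin1 = nn.Linear(512, 2048, bias=False)
lin2 = nn.Linear(2048, 512, bias=False)

# They collapse into a single linear layer whose weight is W2 @ W1.
collapsed = nn.Linear(512, 512, bias=False)
with torch.no_grad():
    collapsed.weight.copy_(lin2.weight @ lin1.weight)

x = torch.randn(2, 16, 512)  # (batch, seq_len, d_model)
print(torch.allclose(lin2(lin1(x)), collapsed(x), atol=1e-4))  # True
```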
Position-wise: Crucially, the same FFN (same weights) is applied to every position independently. This is efficient and helps the model learn general transformations that work regardless of position.
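In other words, for each position’s vector x the block computes (with GELU in place of the ReLU used in the original Transformer paper):

$$\mathrm{FFN}(x) = \mathrm{GELU}(xW_1 + b_1)\,W_2 + b_2, \qquad W_1 \in \mathbb{R}^{d_{\text{model}} \times d_{\text{ff}}},\quad W_2 \in \mathbb{R}^{d_{\text{ff}} \times d_{\text{model}}}$$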
Implementation
```python
import torch.nn as nn


class FeedForward(nn.Module):
    """Position-wise feed-forward network."""

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        # Expand dimension
        self.linear1 = nn.Linear(d_model, d_ff)
        # GELU activation (used in GPT-2, GPT-3)
        self.activation = nn.GELU()
        self.dropout1 = nn.Dropout(dropout)
        # Project back to d_model
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        x = self.linear1(x)       # → (batch, seq_len, d_ff)
        x = self.activation(x)
        x = self.dropout1(x)
        x = self.linear2(x)       # → (batch, seq_len, d_model)
        x = self.dropout2(x)
        return x
```
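As a quick sanity check, the module can be run on a dummy batch to confirm that the output shape matches the input. The sizes below are just the example values from above (d_model=512, d_ff=2048); this snippet is a sketch, not part of the linked source file:

```python
import torch

ffn = FeedForward(d_model=512, d_ff=2048)
x = torch.randn(2, 16, 512)   # (batch=2, seq_len=16, d_model=512)
out = ffn(x)
print(out.shape)              # torch.Size([2, 16, 512]), same shape as the input
```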
Full Code
See the full implementation: src/transformer/feedforward.py