Position-Wise Feed-Forward Networks

What is a feed-forward network? After attention gathers information from across the sequence, we need to actually process that information. The feed-forward network (FFN) is a simple two-layer neural network—also called a Multi-Layer Perceptron (MLP)—that transforms each token’s representation independently.

Why do we need it? Think of the attention layer as “communication” between tokens—gathering relevant context. The FFN is the “computation” step—processing that gathered information to extract useful features and patterns. Without the FFN, the model would only shuffle information around without transforming it. The FFN transforms each token’s representation in three steps:

  1. Expand: Project from d_model (e.g., 512) to d_ff (typically 4× larger, e.g., 2048). This expansion gives the model more “capacity” to learn complex patterns.

  2. Activate: Apply the GELU activation, a smooth nonlinear function. Without this nonlinearity, stacking layers would be pointless (a composition of linear transformations collapses to a single linear transformation).

  3. Project back: Compress back down from d_ff to d_model so the output shape matches the input, allowing us to stack more layers.
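Written as a single formula, the three steps together are (this is the standard formulation from the Transformer literature, with GELU in place of the original paper’s ReLU):

$$
\mathrm{FFN}(x) = W_2\,\mathrm{GELU}(W_1 x + b_1) + b_2,
\qquad W_1 \in \mathbb{R}^{d_{\mathrm{ff}} \times d_{\mathrm{model}}},\;
W_2 \in \mathbb{R}^{d_{\mathrm{model}} \times d_{\mathrm{ff}}}
$$

Here GELU(x) = x·Φ(x), with Φ the standard normal CDF; it acts as a smooth variant of ReLU.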

Position-wise: Crucially, the same FFN (same weights) is applied to every position independently. This is efficient and helps the model learn general transformations that work regardless of position.
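A minimal sketch of what “position-wise” means in PyTorch terms: nn.Linear only acts on the last dimension, so one shared weight matrix is applied at every (batch, position) pair. The shapes here are illustrative:

import torch
import torch.nn as nn

lin = nn.Linear(512, 2048)     # one weight matrix, shared by all positions
x = torch.randn(2, 10, 512)    # (batch, seq_len, d_model)
print(lin(x).shape)            # torch.Size([2, 10, 2048]): applied independently at each of the 10 positions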

import torch.nn as nn

class FeedForward(nn.Module):
    """Position-wise feed-forward network."""

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        # Expand dimension
        self.linear1 = nn.Linear(d_model, d_ff)
        # GELU activation (used in GPT-2, GPT-3)
        self.activation = nn.GELU()
        self.dropout1 = nn.Dropout(dropout)
        # Project back to d_model
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        x = self.linear1(x)      # → (batch, seq_len, d_ff)
        x = self.activation(x)
        x = self.dropout1(x)
        x = self.linear2(x)      # → (batch, seq_len, d_model)
        x = self.dropout2(x)
        return x
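
A quick usage sketch (the dimensions and the allclose check are illustrative; dropout is disabled so the comparison is deterministic):

import torch

ffn = FeedForward(d_model=512, d_ff=2048, dropout=0.0)
ffn.eval()                       # disable dropout for a deterministic check

x = torch.randn(2, 10, 512)      # (batch, seq_len, d_model)
out = ffn(x)
print(out.shape)                 # torch.Size([2, 10, 512]): same shape as the input, so blocks stack

# Position-wise check: feeding position 3 alone gives the same result as
# processing the whole sequence, because positions never mix inside the FFN.
print(torch.allclose(ffn(x[:, 3]), out[:, 3], atol=1e-6))  # True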

See the full implementation: src/transformer/feedforward.py