Multi-Head Attention
Capturing Multiple Types of Relationships
Why One Attention Head Isn't Enough
A single attention head computes one set of attention weights — one way of deciding which tokens are relevant to which. But language has many simultaneous types of relationships:
Consider the word "sat" in "The cat sat down":
- Syntactic: "sat" needs to find its subject ("cat") and its modifier ("down")
- Semantic: "sat" relates to the concept of physical position
- Positional: "sat" sits immediately adjacent to both "cat" and "down"
A single attention head must compress all these relationship types into one set of weights. It might learn to focus on syntactic relationships but miss positional ones, or vice versa.
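To make that bottleneck concrete, here is a minimal single-head scaled dot-product attention sketch in PyTorch. The function name and tensor shapes are illustrative, not from the original; the point is that one softmax produces one weight matrix that must serve every relationship type at once:

```python
import torch
import torch.nn.functional as F

def single_head_attention(x, W_q, W_k, W_v):
    """One head: a single set of attention weights must encode
    syntactic, semantic, and positional relationships together."""
    Q = x @ W_q                              # (seq_len, d_head) queries
    K = x @ W_k                              # (seq_len, d_head) keys
    V = x @ W_v                              # (seq_len, d_head) values
    d_head = Q.shape[-1]
    scores = Q @ K.T / d_head**0.5           # (seq_len, seq_len)
    weights = F.softmax(scores, dim=-1)      # ONE pattern for everything
    return weights @ V                       # (seq_len, d_head)
```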
Multi-head attention solves this by running multiple attention heads in parallel, each with its own Q, K, V projections (see the sketch after this list). Each head can specialize in a different relationship type:
- Head 0 might learn syntactic dependencies (subject-verb)
- Head 1 might learn positional/local patterns (adjacent words)
- Head 2 might learn semantic similarity
- Head 3 might learn coreference (pronoun resolution)
In practice, researchers find that different heads do specialize, though the patterns are often more nuanced than these clean categories.
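The sketch below shows the standard construction in PyTorch: each head gets its own slice of the Q, K, V projections, computes its own attention pattern in parallel, and the head outputs are concatenated and mixed by a final linear layer. The class name, dimension choices, and the omission of masking and dropout are assumptions for brevity:

```python
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must divide evenly across heads"
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        # One big projection each for Q, K, V; reshaped into per-head slices
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)  # mixes the concatenated heads

    def forward(self, x):
        batch, seq_len, d_model = x.shape
        # Project, then split the last dim into (n_heads, d_head)
        def split(t):
            return t.view(batch, seq_len, self.n_heads, self.d_head).transpose(1, 2)
        Q, K, V = split(self.w_q(x)), split(self.w_k(x)), split(self.w_v(x))
        # Each head computes its own independent attention pattern
        scores = Q @ K.transpose(-2, -1) / self.d_head**0.5
        weights = scores.softmax(dim=-1)         # (batch, n_heads, seq, seq)
        out = weights @ V                        # (batch, n_heads, seq, d_head)
        # Concatenate heads back into d_model, then mix
        out = out.transpose(1, 2).reshape(batch, seq_len, d_model)
        return self.w_o(out)
```

With `d_model=64` and `n_heads=8`, for example, each head attends within its own 8-dimensional subspace, so a head focused on subject-verb links does not have to compromise with one tracking adjacent words.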
One Head Cannot Capture All Relationship Types
| Relationship Type | Example in "The cat sat down" | Single Head Can Capture? |
|---|---|---|
| Subject-verb | "sat" ← "cat" (who sat?) | Maybe, if this is what it learns |
| Verb-modifier | "sat" ← "down" (how?) | Conflicts with subject-verb focus |
| Determiner-noun | "The" → "cat" (which cat?) | May be ignored if head focuses on verbs |
| Local context | Adjacent word patterns | May be missed for long-range focus |
Scaling Attention with Multiple Heads
| Approach | Heads | What Each Head Sees | Capacity |
|---|---|---|---|
| Single-head | 1 | One attention pattern for everything | Limited — must compromise |
| Multi-head (h=2) | 2 | Each head has its own Q, K, V weights | 2 independent patterns |
| Multi-head (h=8) | 8 | Each head learns different relationships | 8 independent patterns |
| GPT-3 (h=96) | 96 | Rich, diverse attention patterns | Massive capacity |
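Note that in the standard design, adding heads does not multiply parameters: the model dimension is split across heads (d_head = d_model / h), so the total Q/K/V cost stays constant while the number of independent patterns grows. A quick arithmetic sketch (the GPT-3 figures of d_model = 12288 and h = 96 come from its published configuration; the print formatting is illustrative):

```python
# Heads split d_model: more heads means more patterns, not more parameters.
configs = [
    ("single-head",     512,   1),
    ("multi-head h=8",  512,   8),
    ("GPT-3 h=96",      12288, 96),  # published GPT-3 (175B) dimensions
]
for name, d_model, n_heads in configs:
    d_head = d_model // n_heads
    qkv_params = 3 * d_model * d_model  # same total regardless of head count
    print(f"{name}: d_head={d_head}, Q/K/V params={qkv_params:,}")
```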