Multi-Head Attention

A single attention operation lets each token focus on the others in one way. But language has many kinds of relationships happening at once: grammar, meaning, references, word order. Multi-head attention addresses this by running several attention operations in parallel, each free to focus on something different, and then combining their results. This is the form of attention used in the standard Transformer architecture.

💡 In one line: Multi-head attention runs several attentions in parallel — each learning a different pattern — and merges them for a richer representation.

Why Multiple Heads?

A single attention operation produces just one set of weights, a single "perspective." But to represent a word well, its token may need to attend to several things at once: the subject of the sentence, the main verb, what a pronoun refers to. A single head can't do all of this well.

Multiple heads give multiple attention patterns simultaneously, each capturing a different type of relationship.

How Multi-Head Attention Works

Instead of computing attention once, multi-head attention does this:

  1. Split into h heads — create h separate sets of Q/K/V, each with its own learned projections.
  2. Attend in parallel — each head runs scaled dot-product attention independently.
  3. Concatenate — join the h outputs back together.
  4. Project — pass the result through a final linear layer (W_O) to mix the heads.

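Each head works in a smaller dimension: dₖ = d_model / h. With d_model = 512 and h = 8, each head has size 64, and concatenating the eight outputs restores the original 512. Because the total width is unchanged, multi-head attention costs about the same as a single attention at full width, while letting each head learn its own attention pattern.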
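The four steps above can be sketched from scratch in a few lines of NumPy. This is a minimal illustration, not a production implementation: it assumes self-attention (Q, K, and V all come from the same input x), uses random weights, and leaves out batching, masking, biases, and dropout. The function and variable names (multi_head_attention, W_q, W_o, and so on) are chosen here for illustration.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row maximum first for numerical stability.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(x, W_q, W_k, W_v, W_o):
    """x: (seq_len, d_model); W_q, W_k, W_v: (h, d_model, d_k); W_o: (h * d_k, d_model)."""
    h, d_model, d_k = W_q.shape
    head_outputs = []
    for i in range(h):
        # Step 1: each head has its own projections for Q, K, V.
        Q = x @ W_q[i]                      # (seq_len, d_k)
        K = x @ W_k[i]
        V = x @ W_v[i]
        # Step 2: scaled dot-product attention inside this head.
        scores = Q @ K.T / np.sqrt(d_k)     # (seq_len, seq_len)
        weights = softmax(scores, axis=-1)  # each row sums to 1
        head_outputs.append(weights @ V)    # (seq_len, d_k)
    # Step 3: concatenate the h head outputs along the feature axis.
    concat = np.concatenate(head_outputs, axis=-1)  # (seq_len, h * d_k)
    # Step 4: final linear projection mixes information across heads.
    return concat @ W_o                     # (seq_len, d_model)

rng = np.random.default_rng(0)
seq_len, d_model, h = 5, 512, 8
d_k = d_model // h

x = rng.normal(size=(seq_len, d_model))
W_q = rng.normal(size=(h, d_model, d_k)) * 0.02
W_k = rng.normal(size=(h, d_model, d_k)) * 0.02
W_v = rng.normal(size=(h, d_model, d_k)) * 0.02
W_o = rng.normal(size=(h * d_k, d_model)) * 0.02

out = multi_head_attention(x, W_q, W_k, W_v, W_o)
print(out.shape)  # (5, 512)
```

The explicit loop over heads is for readability only; practical implementations usually compute all heads at once with a single batched matrix multiplication.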

The Intuition: Different Heads, Different Jobs

When researchers inspect trained Transformers, they often find that individual heads specialise:

  • Some heads track syntax (e.g. verb → subject).
  • Some follow coreference (e.g. a pronoun → the noun it refers to).
  • Some simply attend to the previous or next word.

No single head learns everything — but together they build a much richer understanding than one head could.

Dimensions

With d_model = 512 and h = 8 heads:

  • Each head uses dₖ = 512 / 8 = 64.
  • Concatenating the 8 head outputs gives 8 × 64 = 512 values per token, matching d_model.
  • A final W_O projection mixes them.
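The arithmetic above can be checked directly. The short sketch below also counts projection weights to illustrate the "same cost" claim, assuming no bias terms and comparing against a hypothetical single head that runs at the full width d_model.

```python
d_model = 512
h = 8

# d_model must divide evenly among the heads.
assert d_model % h == 0
d_k = d_model // h

# Weights in the Q, K, V projections (biases ignored).
per_head = 3 * d_model * d_k            # one head: W_Q, W_K, W_V of shape (d_model, d_k)
all_heads = h * per_head                # all h heads together
single_full = 3 * d_model * d_model     # one head at full width d_model

# The output projection W_O has shape (d_model, d_model).
w_o = d_model * d_model

print(d_k)                       # 64
print(h * d_k)                   # 512
print(all_heads == single_full)  # True
print(all_heads + w_o)           # 1048576
```

Splitting into heads leaves the Q/K/V weight count unchanged; it only partitions the same total width into h independent slices.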

The formula:

MultiHead(Q,K,V) = Concat(head₁, …, head_h) · W_O
   where  headᵢ = Attention(Q·W_Qⁱ, K·W_Kⁱ, V·W_Vⁱ)
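One practical consequence of this formula is worth a quick check: the h separate per-head projections can be stored as one d_model × d_model matrix, applied once, and then split into heads. The NumPy sketch below illustrates this for the query projection only, with random weights; the variable names are assumptions made for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d_model, h = 4, 512, 8
d_k = d_model // h

x = rng.normal(size=(seq_len, d_model))

# Per-head query projections W_Q^i, each of shape (d_model, d_k).
W_q_heads = [rng.normal(size=(d_model, d_k)) for _ in range(h)]

# As the formula is written: project separately for every head.
q_per_head = [x @ W for W in W_q_heads]            # h arrays of (seq_len, d_k)

# Fused version: place the per-head matrices side by side,
# project once, then reshape the result into heads.
W_q_fused = np.concatenate(W_q_heads, axis=1)      # (d_model, d_model)
q_fused = (x @ W_q_fused).reshape(seq_len, h, d_k)

for i in range(h):
    assert np.allclose(q_fused[:, i, :], q_per_head[i])
print("fused projection matches per-head projections")
```

The same reasoning applies to the key and value projections, which is why implementations often hold one large weight matrix per projection rather than h small ones.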


Code Example

PyTorch's built-in module, torch.nn.MultiheadAttention, handles the splitting, parallel attention, concatenation, and projection for you. (Requires PyTorch: pip install torch.)
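A minimal sketch of such a call is shown below, assuming torch.nn.MultiheadAttention is the module meant and that it is used for self-attention (the same tensor passed as query, key, and value). The sizes are illustrative, the weights are untrained, and the average_attn_weights argument assumes a reasonably recent PyTorch release.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

d_model, num_heads = 512, 8
batch, seq_len = 2, 10

# batch_first=True makes inputs and outputs (batch, seq_len, d_model).
mha = nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads, batch_first=True)

x = torch.randn(batch, seq_len, d_model)

# Self-attention: query, key, and value are all the same tensor.
out, attn_weights = mha(x, x, x)
print(out.shape)           # torch.Size([2, 10, 512])
print(attn_weights.shape)  # torch.Size([2, 10, 10]), averaged over heads by default

# Ask for one attention map per head instead of the average.
_, per_head_weights = mha(x, x, x, average_attn_weights=False)
print(per_head_weights.shape)  # torch.Size([2, 8, 10, 10])
```

With random, untrained weights the per-head maps carry no meaningful pattern; the second call only shows where per-head weights can be read out for inspection.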

Why It Matters

  • Richer representations: multiple relationship types are captured at once.
  • Similar cost: each head runs in a smaller dimension, so the total is roughly the same as a single attention at full width.
  • Standard in practice: the original Transformer and models such as BERT, GPT, and T5 all use multi-head attention.

Summary

  • Multi-head attention runs h attention operations in parallel, each with its own Q/K/V projections.
  • The heads specialise — different ones capture syntax, coreference, position, and more.
  • Each head works in dimension dₖ = d_model / h; outputs are concatenated and passed through W_O.
  • It gives a richer representation than single attention at roughly the same cost.
  • It's the standard form of attention in Transformers, applied as self-attention or cross-attention (next).