The Attention Mechanism

The attention mechanism is the core innovation behind the transformer architecture. It lets the model dynamically weigh the importance of different parts of the input when processing each token, a form of "selective focus" that resembles how people read.

The Intuition

Imagine translating the sentence: "The animal didn't cross the street because it was too tired."

What does "it" refer to? To answer this, you instinctively look back at the sentence and focus your attention on "animal," not "street." Attention gives the model a way to do the same: when processing "it," a trained model can place most of its weight on the representation of "animal."

Query, Key, Value: A Database Analogy

Attention is often explained using a soft database lookup analogy:

  • Query (Q): The current token asks a question — "What context is relevant to understanding me?"
  • Key (K): Every token in the sequence, including the current one, advertises what it contains: "Here's what I'm about"
  • Value (V): The actual information to retrieve — "Here's what I contribute"

The model computes a dot-product similarity between the current query and every key, scales the scores by the square root of the key dimension d_k, converts them to weights with a softmax, and returns the weighted sum of the values.

import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Compute scaled dot-product attention.
    
    Q: (batch, heads, seq_len, d_k) — queries
    K: (batch, heads, seq_len, d_k) — keys  
    V: (batch, heads, seq_len, d_v) — values
    Returns: (batch, heads, seq_len, d_v) — context vectors
    """
    d_k = Q.size(-1)
    
    # Compute similarity scores: (batch, heads, seq_len, seq_len)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)
    
    # Apply the mask if given: positions where mask == 0 are blocked
    # (for a causal decoder, these are the future tokens)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    
    # Convert to probabilities
    attention_weights = F.softmax(scores, dim=-1)
    
    # Weighted sum of values
    output = torch.matmul(attention_weights, V)
    return output, attention_weights
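
To make the arithmetic visible on a small case, here is a minimal sketch of the same computation for a single query, written in plain Python rather than PyTorch. The helper names `softmax` and `attend` and all of the vectors are illustrative assumptions, not part of the lesson's code.

```python
import math

def softmax(xs):
    # Subtract the max for numerical stability; the result is unchanged.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def attend(query, keys, values):
    """Single-query scaled dot-product attention over plain lists."""
    d_k = len(query)
    # Similarity of the query with every key, scaled by sqrt(d_k).
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d_k)
              for key in keys]
    weights = softmax(scores)
    # Weighted sum of the value vectors, one output dimension at a time.
    d_v = len(values[0])
    output = [sum(w * v[j] for w, v in zip(weights, values))
              for j in range(d_v)]
    return output, weights

# Illustrative vectors only: the query is most similar to the first key.
query = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
values = [[10.0, 0.0], [0.0, 10.0], [5.0, 5.0]]

output, weights = attend(query, keys, values)
print(weights)  # largest weight on the first key, smallest on the third
print(output)   # a blend of the values, pulled toward values[0]
```

The weights always sum to 1, so the output is a weighted average of the value vectors: the closer a key is to the query, the more its value contributes.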

Causal Masking in Decoder Models

GPT-style (decoder-only) models let each token attend only to itself and earlier tokens, never to future ones. This is enforced with a causal mask: a lower-triangular matrix marks the allowed positions, and the scores at future positions are set to negative infinity before the softmax, which drives their attention weights to zero.

Attention mask for a 4-token sequence:
Token 1 can attend to: [1, -, -, -]
Token 2 can attend to: [1, 2, -, -]
Token 3 can attend to: [1, 2, 3, -]
Token 4 can attend to: [1, 2, 3, 4]

This is what makes autoregressive generation work: each generated token can only "see" what came before it.
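
The masking step can be sketched in plain Python as follows. The function names `causal_mask` and `masked_softmax` are assumptions made for illustration, and the raw scores are set to zero so that only the effect of the mask shows up in the result.

```python
import math

def causal_mask(n):
    # mask[i][j] is 1 when position i may attend to position j (j <= i).
    return [[1 if j <= i else 0 for j in range(n)] for i in range(n)]

def masked_softmax(row_scores, row_mask):
    # Blocked positions get -inf, so exp(-inf) = 0 removes them from the sum.
    masked = [s if m == 1 else float("-inf")
              for s, m in zip(row_scores, row_mask)]
    top = max(masked)  # the diagonal is always allowed, so this is finite
    exps = [math.exp(s - top) for s in masked]
    total = sum(exps)
    return [e / total for e in exps]

n = 4
mask = causal_mask(n)
# Uniform raw scores, so the allowed positions share the weight equally.
scores = [[0.0] * n for _ in range(n)]
weights = [masked_softmax(scores[i], mask[i]) for i in range(n)]
for row in weights:
    print([round(w, 2) for w in row])
```

With uniform scores, the first row puts all of its weight on token 1 and the last row spreads 0.25 across all four tokens. In PyTorch, a mask with this layout is commonly built with `torch.tril(torch.ones(n, n))`, which matches the `mask == 0` convention used in `scaled_dot_product_attention`.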

Multi-Head Attention

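Instead of computing attention once, transformers compute it in parallel across multiple heads, each with its own learned Q, K, and V projections. In the code below, each role uses one d_model × d_model matrix, and each head reads its own d_k-wide slice of the result: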

class MultiHeadAttention(torch.nn.Module):
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        
        # One fused projection per role; each head uses its own d_k-wide slice of the output
        self.W_q = torch.nn.Linear(d_model, d_model)
        self.W_k = torch.nn.Linear(d_model, d_model)
        self.W_v = torch.nn.Linear(d_model, d_model)
        self.W_o = torch.nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        batch_size, seq_len, d_model = x.shape
        
        # Project and reshape for multi-head computation
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        
        # Attend in parallel across all heads; for decoder use, pass a
        # (seq_len, seq_len) causal mask, which broadcasts over batch and heads
        output, _ = scaled_dot_product_attention(Q, K, V, mask=mask)
        
        # Concatenate heads and project back
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        return self.W_o(output)
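
The `view(...).transpose(1, 2)` reshaping above can be hard to picture. The sketch below mimics it on nested Python lists for a single sequence (no batch dimension). The helper names `split_heads` and `merge_heads` are hypothetical and stand in for the tensor operations only; the projections and the attention step are left out.

```python
def split_heads(x, num_heads):
    """(seq_len, d_model) nested list -> (num_heads, seq_len, d_k)."""
    d_model = len(x[0])
    assert d_model % num_heads == 0
    d_k = d_model // num_heads
    # Head h takes columns [h * d_k, (h + 1) * d_k) of every token.
    return [[row[h * d_k:(h + 1) * d_k] for row in x]
            for h in range(num_heads)]

def merge_heads(heads):
    """(num_heads, seq_len, d_k) -> (seq_len, d_model), concatenated per token."""
    seq_len = len(heads[0])
    return [[v for head in heads for v in head[t]] for t in range(seq_len)]

# Two tokens with d_model = 4, split into 2 heads of width d_k = 2.
x = [[1, 2, 3, 4],
     [5, 6, 7, 8]]
heads = split_heads(x, num_heads=2)
print(heads)               # [[[1, 2], [5, 6]], [[3, 4], [7, 8]]]
print(merge_heads(heads))  # [[1, 2, 3, 4], [5, 6, 7, 8]]
```

Splitting and merging are exact inverses, so the heads differ only in which slice of the projected vector they see, and the final `W_o` projection is what mixes information across heads.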

What Different Heads Learn

Studies of attention heads in trained transformers report that different heads tend to specialize in different linguistic phenomena:

  • Syntactic heads: Track grammatical dependencies (subject-verb, adjective-noun)
  • Semantic heads: Capture meaning relationships (synonyms, antonyms)
  • Positional heads: Attend to nearby tokens (local context)
  • Rare token heads: Focus on infrequent but semantically significant tokens

This specialization emerges during training; it is not explicitly programmed.

Practical Implications

Context window efficiency: Attention compute scales as O(n²) in sequence length n, because every token is scored against every other token. Doubling the context window therefore quadruples the attention computation. This is why extending context windows is expensive and why techniques like sliding window attention (Mistral) and sparse attention (Longformer) exist.
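
The scaling can be checked with a quick count of query-key pairs. This is a rough sketch: the window size of 256 is an arbitrary value chosen for illustration, and the sliding-window count assumes each token attends only to itself and the tokens immediately before it.

```python
def full_attention_pairs(n):
    # Every token scores against every token: an n x n matrix.
    return n * n

def sliding_window_pairs(n, window):
    # Each token scores against itself and up to (window - 1) preceding tokens.
    return sum(min(i + 1, window) for i in range(n))

for n in (1024, 2048, 4096):
    print(n, full_attention_pairs(n), sliding_window_pairs(n, window=256))
```

Each doubling of n multiplies the full count by four, while the windowed count grows roughly linearly once n is much larger than the window.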

Prompt position matters: Models tend to use information near the beginning and end of a long context more reliably than information in the middle, a phenomenon known as the "lost in the middle" effect. For long-context RAG, place the most important content near the edges.
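
One way to act on this, sketched here under the assumption that retrieved passages arrive already ranked from most to least relevant, is to alternate them between the front and the back of the prompt. The function name `order_for_edges` is hypothetical and not taken from any library.

```python
def order_for_edges(ranked):
    """Reorder best-first items so the strongest sit at the edges.

    Items at even ranks fill the front in order; items at odd ranks fill
    the back in reverse, so the weakest items end up in the middle.
    """
    front, back = [], []
    for i, item in enumerate(ranked):
        (front if i % 2 == 0 else back).append(item)
    return front + back[::-1]

# "a" is the most relevant passage, "e" the least.
print(order_for_edges(["a", "b", "c", "d", "e"]))  # ['a', 'c', 'e', 'd', 'b']
```

The best passage lands first, the second best lands last, and the least relevant passage sits in the middle, where it matters least if the model overlooks it.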

Attention visualization: Tools like BertViz show which tokens attend to which, and this can help debug unexpected model behavior on specific inputs.
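
Before reaching for a visualization tool, a text summary of the weights is often enough for a first look. The sketch below is not BertViz and does not use its API; `top_attended` is a hypothetical helper, and the weight matrix is invented for illustration, not taken from a real model.

```python
def top_attended(tokens, weights, k=2):
    """For each query position, list the k key tokens with the largest weights."""
    report = {}
    for i, row in enumerate(weights):
        ranked = sorted(range(len(row)), key=lambda j: row[j], reverse=True)
        report[i] = [(tokens[j], row[j]) for j in ranked[:k]]
    return report

tokens = ["animal", "street", "it"]
# Made-up causal attention weights; each row sums to 1.
weights = [
    [1.0, 0.0, 0.0],
    [0.3, 0.7, 0.0],
    [0.6, 0.1, 0.3],
]

for i, top in top_attended(tokens, weights).items():
    print(tokens[i], "->", top)
```

In this invented example, the row for "it" ranks "animal" first, which is the pattern the intuition section describes.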