Self-Attention Mechanism

A complete, in-depth exploration of the attention mechanism that revolutionized deep learning and enabled modern large language models

📚 Comprehensive Deep Dive ⏱️ ~2 hours 🎯 Interactive Visualizations 🐍 PyTorch Code

📖 Table of Contents

  1. Historical Context & Motivation
  2. The Core Intuition
  3. Query, Key, Value: Deep Dive
  4. Mathematical Foundations
  5. Interactive Visualization
  6. PyTorch Implementation
  7. Causal Masking
  8. Computational Complexity
  9. Attention Variants
  10. Important Q&A
  11. Your Questions
  12. Comprehension Quiz

1. Historical Context & Motivation

The Problem with Sequential Models

Before transformers, the dominant architectures for sequence processing were Recurrent Neural Networks (RNNs) and their variants (LSTMs, GRUs). These models processed sequences one token at a time, maintaining a hidden state that accumulated information:

h_t = f(h_{t-1}, x_t)
The RNN hidden state depends on the previous hidden state and the current input
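The recurrence above can be sketched as a simple loop. This is a minimal illustration with a hypothetical tanh cell and random, untrained weights, not a production RNN:

```python
import torch

def rnn_step(h, x, W_h, W_x):
    # One recurrence step: new state from previous state and current input
    return torch.tanh(h @ W_h + x @ W_x)

torch.manual_seed(0)
n, d = 8, 16                      # sequence length, state size
xs = torch.randn(n, d)            # input token embeddings
W_h = torch.randn(d, d) * 0.1     # random stand-ins for learned weights
W_x = torch.randn(d, d) * 0.1

h = torch.zeros(d)
for t in range(n):
    # Inherently sequential: step t cannot run until step t-1 has produced h
    h = rnn_step(h, xs[t], W_h, W_x)
print(h.shape)  # torch.Size([16])
```

The loop body is the whole problem: each iteration depends on the previous one, so the sequence cannot be processed in parallel.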

This sequential dependency created several critical problems:

  • No parallelism: step t cannot begin until step t-1 finishes, so training cannot exploit parallel hardware across the sequence
  • Vanishing/exploding gradients: gradients must flow backward through every intermediate step, degrading over long distances
  • Information bottleneck: all past context must be compressed into a single fixed-size hidden state

The Attention Revolution

Attention was first introduced in 2014 for machine translation (Bahdanau et al.). Instead of compressing all source information into a single vector, the decoder could "look back" at all encoder states and decide which ones were relevant.

The 2017 paper "Attention Is All You Need" (Vaswani et al.) took this further: what if we used ONLY attention, with no recurrence at all? This gave birth to the Transformer architecture.

💡 The Fundamental Insight

Self-attention allows every position in a sequence to directly attend to every other position in a single operation. Information doesn't need to "travel" through intermediate states - position 1 can directly influence position 100 with equal ease.

2. The Core Intuition

Attention as Dynamic, Content-Based Routing

Think of self-attention as a sophisticated information routing system. For each position in the sequence, we ask: "Given what I'm trying to compute here, which other positions in this sequence contain relevant information, and how should I combine that information?"

Unlike convolutions (which have fixed receptive fields) or RNNs (which have fixed routing through time), attention routing is:

  • Dynamic: the routing pattern is recomputed for every input sequence
  • Content-based: weights depend on what the tokens contain, not only where they sit
  • Global: any position can route information to any other position in a single step

The Database Analogy

A useful mental model is to think of attention like a fuzzy database lookup:

Query: "What am I looking for?"

Each position generates a query vector that encodes what information it needs. In a sentence like "The cat sat on the mat", when processing "sat", the query might encode "I need to find the subject of this action."

Key: "What do I contain?"

Each position also generates a key vector that advertises what information it holds. The key for "cat" might encode "I am a noun, likely a subject." The key for "the" might encode "I am a determiner."

Matching: Query · Key = Relevance Score

We compute dot products between the query and all keys. High dot product = high similarity = high relevance. The query for "sat" will have high similarity with the key for "cat" (its subject).

Value: "Here's my actual content"

Each position has a value vector containing its actual information to share. After determining relevance scores, we take a weighted sum of all values, where weights come from the relevance scores.

🔬 Why Separate Keys and Values?

You might wonder: why not just use the same vector for keys and values? The separation allows the model to independently learn:

  • What to match on (keys) - the "address" or "index" properties
  • What to retrieve (values) - the actual "content" to aggregate

This is similar to how a hash table has separate keys (for lookup) and values (for storage). A word's syntactic role (encoded in key) might differ from its semantic content (encoded in value).
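The lookup analogy can be made concrete with a toy example. The vectors below are tiny and hand-picked purely for illustration:

```python
import torch
import torch.nn.functional as F

# Toy "database": 3 stored items, each with a key (address) and a value (content)
keys = torch.tensor([[1.0, 0.0],    # item 0: advertises "noun-like"
                     [0.0, 1.0],    # item 1: advertises "verb-like"
                     [0.7, 0.7]])   # item 2: a bit of both
values = torch.tensor([[10.0, 0.0],
                       [0.0, 10.0],
                       [5.0, 5.0]])

# A query looking mostly for "noun-like" information
query = torch.tensor([1.0, 0.2])

# Relevance scores: dot product of the query with every key
scores = keys @ query               # shape (3,)

# Soft lookup: softmax turns scores into weights, then blend the values
weights = F.softmax(scores, dim=0)
result = weights @ values           # weighted mix of the stored contents

print(weights)  # highest weight on item 0 (best key match)
print(result)
```

Unlike a hash table, no single item is returned: every value contributes, in proportion to how well its key matches the query.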

3. Query, Key, Value: Deep Dive

The Linear Projections

Given an input sequence X ∈ ℝ^(n×d) (n tokens, each with a d-dimensional embedding), we create Q, K, V through learned linear transformations:

Q = X · W_Q ∈ ℝ^(n×d_k)

K = X · W_K ∈ ℝ^(n×d_k)

V = X · W_V ∈ ℝ^(n×d_v)
W_Q, W_K ∈ ℝ^(d×d_k) and W_V ∈ ℝ^(d×d_v) are learned parameter matrices

Dimension Choices

In the original transformer and most modern LLMs:

  • The original transformer used d_model = 512 with h = 8 heads, so d_k = d_v = 64 per head
  • GPT-2 (base) uses d_model = 768 with 12 heads of 64 dimensions each
  • In general, d_k = d_v = d_model / h, so concatenating all heads restores the model dimension
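A quick shape check of the three projections, using GPT-2-style sizes for a single head. The untrained nn.Linear layers stand in for the learned matrices:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, d_k, d_v = 768, 64, 64    # GPT-2-style sizes (one head)
n = 6                              # sequence length

X = torch.randn(n, d_model)        # one embedding per token

# Three separate learned matrices (untrained here, just for shapes)
W_q = nn.Linear(d_model, d_k, bias=False)
W_k = nn.Linear(d_model, d_k, bias=False)
W_v = nn.Linear(d_model, d_v, bias=False)

Q, K, V = W_q(X), W_k(X), W_v(X)
print(Q.shape, K.shape, V.shape)   # each torch.Size([6, 64])
```

Note that Q, K, and V differ even though they come from the same X, because each projection has its own weights.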

Interactive: QKV Vectors

Query, Key, Value Projection Visualization

For the token "cat", we project the d = 768-dimensional embedding through three different learned matrices (× W_Q, × W_K, × W_V), producing a query vector (d_k = 64), a key vector (d_k = 64), and a value vector (d_v = 64).

[Interactive figure: the 768-dimensional input embedding fanning out into the three 64-dimensional Q, K, V vectors]

⚠️ Common Misconception

Q, K, V are not three different interpretations of the same embedding. They are three completely different projections through three separate learned weight matrices. The same input token produces different Q, K, and V vectors.

What Do These Projections Learn?

During training, the model learns weight matrices that create useful query/key/value spaces:

Research has shown that attention heads often specialize. Some heads learn syntactic relationships (subject-verb), others learn positional patterns (attend to previous token), and others learn semantic relationships (entity coreference).

4. Mathematical Foundations

The Attention Formula

The complete scaled dot-product attention formula is:

Attention(Q, K, V) = softmax( Q Kᵀ / √d_k ) · V    (1)
The fundamental attention equation

Let's break down each component:

Step 1: Compute Attention Scores (Q Kᵀ)

The matrix multiplication Q Kᵀ computes the dot product between every query and every key:

(Q Kᵀ)_ij = q_i · k_j = Σ_m q_im · k_jm
Score between position i (query) and position j (key)

The result is an n×n matrix where entry (i, j) represents how much position i should attend to position j. Higher values indicate higher relevance.

Step 2: Scale by √d_k

We divide by √d_k to counteract the effect of large dot products:

🔬 Why √d_k? A Statistical Argument

Assume q and k are vectors with components drawn independently from a distribution with mean 0 and variance 1. Their dot product is:

q · k = Σ_{i=1}^{d_k} q_i k_i

Since each q_i k_i has variance 1 (the product of two independent unit-variance variables), the sum has variance d_k (a sum of d_k independent terms).

So q · k has standard deviation √d_k. As d_k grows, dot products become larger in magnitude, pushing softmax into saturation regions where gradients nearly vanish.

Dividing by √d_k normalizes the variance back to ~1, keeping gradients healthy regardless of dimension.
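You can verify this statistical argument empirically. The sketch below samples many random q, k pairs and measures the standard deviation of their dot products before and after scaling:

```python
import torch

torch.manual_seed(0)
for d_k in (16, 64, 256, 1024):
    q = torch.randn(20_000, d_k)    # components ~ N(0, 1)
    k = torch.randn(20_000, d_k)
    dots = (q * k).sum(dim=-1)      # 20k sample dot products per dimension
    print(f"d_k={d_k:5d}  "
          f"std(q.k)={dots.std().item():7.2f}  "
          f"std(q.k / sqrt(d_k))={(dots / d_k ** 0.5).std().item():5.2f}")
# std(q.k) grows like sqrt(d_k); the scaled version stays near 1
```

The unscaled standard deviation tracks √d_k (about 4, 8, 16, 32), while the scaled one hovers around 1 for every dimension.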

Step 3: Softmax Normalization

We apply softmax along the key dimension (each row):

α_ij = exp(s_ij) / Σ_m exp(s_im)
where s_ij = (Q Kᵀ)_ij / √d_k

Properties of the resulting attention weights α:

  • Non-negative: every α_ij ∈ (0, 1)
  • Normalized: each row sums to 1, Σ_j α_ij = 1
  • Each row is therefore a probability distribution over key positions

Step 4: Weighted Sum of Values

Finally, we compute a weighted combination of value vectors:

output_i = Σ_j α_ij · v_j
Output for position i is a weighted sum of all value vectors

In matrix form: Output = α · V, where α is the n×n attention weight matrix and V is n×d_v.

💡 The Beautiful Property

The output for each position is a convex combination of all value vectors (weights sum to 1, all non-negative). This means the output lies within the "convex hull" of the value vectors - attention can select, interpolate, or mix information, but cannot extrapolate beyond what's in the values.
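A small sketch that checks these convexity consequences numerically, using random tensors and no learned weights:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, d_k, d_v = 5, 8, 4
Q, K = torch.randn(n, d_k), torch.randn(n, d_k)
V = torch.randn(n, d_v)

weights = F.softmax(Q @ K.T / d_k ** 0.5, dim=-1)
out = weights @ V

# Weights form a probability distribution per row...
assert torch.all(weights >= 0)
assert torch.allclose(weights.sum(dim=-1), torch.ones(n))

# ...so each output coordinate stays inside the per-coordinate
# min/max of the value vectors (a consequence of convexity)
assert torch.all(out <= V.max(dim=0).values + 1e-6)
assert torch.all(out >= V.min(dim=0).values - 1e-6)
print("output lies inside the values' bounding box")
```

The bounding-box check is a weaker condition than membership in the convex hull itself, but it already shows attention cannot produce values outside the range of what the value vectors contain.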

5. Interactive Visualization

Attention Matrix Explorer

[Interactive figure: an attention matrix over the tokens "The cat sat on the mat"; clicking a token highlights its attention pattern (which tokens it attends to)]

Reading the matrix: Row = Query (the position doing the looking), Column = Key (the position being looked at).

This shows causal attention - each position can only attend to itself and previous positions (future positions are masked with "-").

Understanding the Attention Pattern

The visualization shows a simulated attention pattern that a trained language model might produce. Notice:

  • Each position keeps some weight on itself (the diagonal)
  • "sat" places high weight on "cat", its subject
  • Function words like "the" and "on" tend to receive comparatively little attention

6. PyTorch Implementation

Let's implement self-attention from scratch. We'll build up from a minimal version to a more complete implementation.

Minimal Implementation

🐍 Minimal Self-Attention
PyTorch
import torch
import torch.nn.functional as F
import math

def attention(Q, K, V, mask=None):
    """
    Minimal scaled dot-product attention.

    Args:
        Q: Queries (batch, seq_len, d_k)
        K: Keys (batch, seq_len, d_k)
        V: Values (batch, seq_len, d_v)
        mask: Optional mask (batch, seq_len, seq_len) or (1, seq_len, seq_len)

    Returns:
        output: Attended values (batch, seq_len, d_v)
        weights: Attention weights (batch, seq_len, seq_len)
    """
    d_k = Q.size(-1)

    # Compute attention scores
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

    # Apply mask (if provided)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)

    # Softmax to get attention weights
    weights = F.softmax(scores, dim=-1)

    # Weighted sum of values
    output = torch.matmul(weights, V)

    return output, weights

Complete Self-Attention Module

🐍 Full Self-Attention Class
PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple

class SelfAttention(nn.Module):
    """
    Scaled Dot-Product Self-Attention.

    This is single-head attention. For multi-head, see the next lesson.

    Args:
        d_model: Input embedding dimension
        d_k: Key/Query dimension (defaults to d_model)
        d_v: Value dimension (defaults to d_k)
        dropout: Dropout probability on attention weights
    """

    def __init__(
        self,
        d_model: int,
        d_k: Optional[int] = None,
        d_v: Optional[int] = None,
        dropout: float = 0.0
    ):
        super().__init__()

        self.d_k = d_k if d_k is not None else d_model
        self.d_v = d_v if d_v is not None else self.d_k
        self.scale = math.sqrt(self.d_k)

        # Linear projections
        # Note: bias=False is common in modern LLMs (LLaMA, etc.)
        self.W_q = nn.Linear(d_model, self.d_k, bias=False)
        self.W_k = nn.Linear(d_model, self.d_k, bias=False)
        self.W_v = nn.Linear(d_model, self.d_v, bias=False)

        # Output projection (projects back to d_model if d_v != d_model)
        self.W_o = nn.Linear(self.d_v, d_model, bias=False)

        # Dropout on attention weights (regularization)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        return_weights: bool = False
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Forward pass.

        Args:
            x: Input tensor (batch, seq_len, d_model)
            mask: Attention mask (batch, 1, seq_len) or (batch, seq_len, seq_len)
                  1 = attend, 0 = mask out
            return_weights: Whether to return attention weights

        Returns:
            output: (batch, seq_len, d_model)
            weights: (batch, seq_len, seq_len) if return_weights=True
        """
        batch_size, seq_len, _ = x.shape

        # Project to Q, K, V
        Q = self.W_q(x)  # (batch, seq_len, d_k)
        K = self.W_k(x)  # (batch, seq_len, d_k)
        V = self.W_v(x)  # (batch, seq_len, d_v)

        # Compute attention scores
        # (batch, seq_len, d_k) @ (batch, d_k, seq_len) -> (batch, seq_len, seq_len)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale

        # Apply mask
        if mask is not None:
            # Use -inf so softmax gives 0 probability
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # Softmax over keys (last dimension)
        weights = F.softmax(scores, dim=-1)

        # Apply dropout to attention weights
        weights = self.dropout(weights)

        # Weighted sum of values
        # (batch, seq_len, seq_len) @ (batch, seq_len, d_v) -> (batch, seq_len, d_v)
        attended = torch.matmul(weights, V)

        # Project back to d_model
        output = self.W_o(attended)

        if return_weights:
            return output, weights
        return output, None


# ===== Example Usage =====
if __name__ == "__main__":
    # Hyperparameters
    batch_size = 2
    seq_len = 10
    d_model = 512

    # Create random input (simulating token embeddings)
    x = torch.randn(batch_size, seq_len, d_model)

    # Create attention layer
    attn = SelfAttention(d_model=d_model, dropout=0.1)

    # Forward pass (no mask = full attention)
    output, weights = attn(x, return_weights=True)

    print(f"Input shape:  {x.shape}")        # (2, 10, 512)
    print(f"Output shape: {output.shape}")   # (2, 10, 512)
    print(f"Weights shape: {weights.shape}") # (2, 10, 10)

    # Verify attention weights sum to 1
    print(f"Row sums: {weights[0].sum(dim=-1)}")  # All ~1.0

💡 Code Walkthrough

  • Four linear projections: W_q, W_k, W_v create Q, K, V, and the output projection W_o maps the attended values back to d_model.
  • Dropout on the attention weights is a form of regularization - randomly zeroing out some attention connections during training.
  • The scores = torch.matmul(Q, K.transpose(-2, -1)) line is the key operation - one batched matrix multiplication computes all pairwise attention scores.
  • Masking uses -inf so softmax produces exactly 0 (e^(-∞) = 0), completely blocking information flow.

7. Causal Masking for Autoregressive Models

Why Causal Masking?

In autoregressive language models (GPT, LLaMA, Claude), we generate text left-to-right, predicting each token based only on previous tokens. During training, we need to enforce this constraint so the model can't "cheat" by looking at future tokens.

Causal mask (also called "look-ahead mask"): Position i can only attend to positions 0, 1, ..., i.

🐍 Causal Mask Implementation
PyTorch
def create_causal_mask(seq_len: int, device: torch.device = None) -> torch.Tensor:
    """
    Creates a causal (lower triangular) attention mask.

    Args:
        seq_len: Sequence length
        device: Device to create tensor on

    Returns:
        mask: (1, seq_len, seq_len) tensor
              1 = can attend, 0 = cannot attend (will become -inf)
    """
    # torch.tril creates lower triangular matrix
    mask = torch.tril(torch.ones(seq_len, seq_len, device=device))
    # Add batch dimension for broadcasting
    return mask.unsqueeze(0)  # (1, seq_len, seq_len)


# Example: seq_len = 5
mask = create_causal_mask(5)
print(mask[0])
# tensor([[1., 0., 0., 0., 0.],   <- token 0 sees only itself
#         [1., 1., 0., 0., 0.],   <- token 1 sees tokens 0, 1
#         [1., 1., 1., 0., 0.],   <- token 2 sees tokens 0, 1, 2
#         [1., 1., 1., 1., 0.],   <- token 3 sees tokens 0, 1, 2, 3
#         [1., 1., 1., 1., 1.]])  <- token 4 sees all tokens


# Combining causal mask with padding mask
def create_combined_mask(
    seq_len: int,
    padding_mask: torch.Tensor = None  # (batch, seq_len), 1=real token, 0=padding
) -> torch.Tensor:
    """
    Creates a mask that combines the causal constraint with padding.

    Returns:
        mask: (batch, seq_len, seq_len) if padding_mask is given,
              else (1, seq_len, seq_len)
    """
    # Start with the causal mask (handle the no-padding case safely)
    device = padding_mask.device if padding_mask is not None else None
    mask = create_causal_mask(seq_len, device=device)

    if padding_mask is not None:
        # Expand padding mask: don't attend to padding tokens
        # (batch, seq_len) -> (batch, 1, seq_len)
        padding = padding_mask.unsqueeze(1)
        # Combine: must satisfy BOTH causal AND not-padding
        mask = mask * padding

    return mask

Bidirectional vs Causal Attention

| Aspect | Bidirectional (BERT-style) | Causal (GPT-style) |
|---|---|---|
| Mask | No mask (or only padding mask) | Lower triangular mask |
| Attention | Each position sees all positions | Each position sees only past + self |
| Use case | Understanding (classification, embeddings) | Generation (language models) |
| Training | Masked language modeling | Next token prediction |
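A quick way to see the difference: perturb only the last token and check which outputs change. The sketch below uses a stripped-down attention with Q = K = V = x (no learned projections), which is enough to demonstrate the masking behavior:

```python
import torch
import torch.nn.functional as F

def attn(x, causal):
    # Stripped-down attention with Q = K = V = x
    n = x.size(0)
    scores = x @ x.T / x.size(-1) ** 0.5
    if causal:
        keep = torch.tril(torch.ones(n, n, dtype=torch.bool))
        scores = scores.masked_fill(~keep, float('-inf'))
    return F.softmax(scores, dim=-1) @ x

torch.manual_seed(0)
x = torch.randn(5, 8)
x2 = x.clone()
x2[4] += 1.0   # perturb only the LAST token

# Causal: positions 0-3 never see position 4, so their outputs are unchanged
assert torch.allclose(attn(x, True)[:4], attn(x2, True)[:4])

# Bidirectional: every position attends to position 4, so all outputs shift
assert not torch.allclose(attn(x, False)[:4], attn(x2, False)[:4])
```

This invariance is exactly what lets causal models be trained on all positions at once without leaking future information.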

8. Computational Complexity Analysis

Time Complexity

For a sequence of length n with dimension d:

| Operation | Complexity | Notes |
|---|---|---|
| Q, K, V projections | O(n · d²) | Linear in sequence length |
| Q Kᵀ (attention scores) | O(n² · d) | Quadratic in sequence length! |
| Softmax | O(n²) | Applied to n×n matrix |
| Attention · V | O(n² · d) | Quadratic in sequence length! |

Total: O(n² · d) - The quadratic dependence on n is the main bottleneck for long sequences.

Memory Complexity

The attention weight matrix α ∈ ℝ^(n×n) must be stored (at least temporarily):

  • Memory grows as O(n²), independent of the model dimension d
  • At n = 100,000 in fp16, the matrix alone is roughly 18.6 GiB per head per layer

⚠️ The Long-Context Challenge

This quadratic complexity is why early transformers were limited to ~512-2048 tokens. Modern techniques that enable 100k+ token contexts include:

  • Flash Attention: Computes attention without materializing the full n×n matrix
  • Sparse Attention: Only compute attention for a subset of positions
  • Linear Attention: Approximations that achieve O(n) complexity

We'll cover these in detail in Module 6 (Inference Optimization).
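A back-of-envelope sketch of that memory wall, assuming the full n×n score matrix is materialized in fp16 (2 bytes per entry) for a single head of a single layer:

```python
# Back-of-envelope: size of the n x n attention matrix in fp16 (2 bytes/entry)
for n in (2_048, 32_768, 100_000):
    gib = n * n * 2 / 2**30
    print(f"n={n:>7,}: {gib:8.2f} GiB per head per layer")
# n=100,000 -> ~18.6 GiB: far too large to materialize on a typical GPU
```

Multiply by the number of heads and layers and the naive approach becomes hopeless, which is precisely what Flash Attention's tiled computation avoids.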

Comparison with RNNs

| Aspect | Self-Attention | RNN |
|---|---|---|
| Time per layer | O(n² · d) | O(n · d²) |
| Sequential operations | O(1) - fully parallel | O(n) - inherently sequential |
| Max path length | O(1) - any position to any position | O(n) - must pass through all intermediate states |

For typical sequence lengths (hundreds to low thousands), self-attention's parallelizability makes it much faster to train on GPUs despite higher theoretical complexity.

9. Attention Variants (Preview)

The basic self-attention we've covered is the foundation, but modern LLMs use several important variants:

Multi-Head Attention

Instead of one attention operation, run h parallel attention "heads" with smaller dimensions (d_k = d_model/h), then concatenate the results. This allows the model to jointly attend to information from different representation subspaces. → Next lesson

Grouped Query Attention (GQA)

Used in Llama 2, Mistral. Share K and V projections across groups of heads to reduce memory (KV cache). We'll cover this in Module 6.

Multi-Query Attention (MQA)

Extreme version of GQA: all heads share the same K and V. Faster inference but potentially lower quality.

Multi-head Latent Attention (MLA)

Used in DeepSeek-V2/V3. Uses low-rank projection to compress KV cache. Covered in Module 7.

Flash Attention

Not a different attention formula, but a memory-efficient implementation that avoids materializing the full attention matrix. Critical for long contexts. Covered in Module 6.

10. Important Questions & Answers

These are fundamental questions that deepen understanding of self-attention:

Q: Why don't we just use the same vector for Q, K, and V (i.e., skip the projections)?

If Q = K = V = X (the input), then the attention scores would simply be X Xᵀ, which computes raw similarity between token embeddings. This has several problems:

  • No learnable parameters: The model couldn't learn what constitutes "relevance" for the task
  • Same space for matching and content: What makes tokens similar (for matching) might differ from what information should be aggregated (content)
  • Limited expressivity: Can only express similarities already present in the embedding space

The separate projections let the model learn task-specific notions of "what to look for" (Q), "what to advertise" (K), and "what to share" (V).

Q: Can attention weights be negative? Can they be greater than 1?

No and no. After softmax, all attention weights are in the range [0, 1] and sum to 1 along each row. This means:

  • The output is always a convex combination of value vectors
  • Attention can select (put all weight on one position), average (spread weight evenly), or anything in between
  • Attention cannot extrapolate - the output lies within the convex hull of the values

Some research has explored allowing negative attention (e.g., "differential attention"), but standard transformers use non-negative weights.

Q: What happens to the attention pattern for very long sequences?

Several issues emerge:

  • Dilution: Softmax normalizes over all positions, so with many positions individual attention weights become very small; models often compensate by piling excess weight onto the first few tokens (the "attention sink" phenomenon)
  • Computational cost: O(n²) becomes prohibitive. 100k tokens = 10 billion attention score computations per layer
  • Memory: The n×n attention matrix can't fit in GPU memory for very long sequences

Modern solutions include positional extrapolation (RoPE scaling), sliding window attention, and Flash Attention (memory-efficient implementation).

Q: Why is dropout applied to attention weights? Isn't that throwing away information?

Attention dropout is a regularization technique that helps with:

  • Preventing over-reliance: Without dropout, the model might learn to always attend to specific positions (like the first token), reducing generalization
  • Robustness: Forces the model to have multiple "paths" for information flow
  • Implicit ensemble: Like all dropout, it trains an ensemble of sub-networks

Note: Attention dropout is typically only applied during training. At inference, all attention weights are used.
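The train/inference distinction is handled by the module mode in PyTorch. A minimal illustration with a bare nn.Dropout standing in for the attention-weight dropout:

```python
import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
w = torch.ones(4)   # stand-in for a row of attention weights

drop.train()        # training mode: entries are zeroed at random,
out = drop(w)       # survivors scaled by 1/(1-p) = 2.0
print(out)          # each entry is either 0.0 or 2.0

drop.eval()         # inference mode: dropout is the identity
print(drop(w))      # tensor([1., 1., 1., 1.])
```

The 1/(1-p) rescaling during training keeps the expected value of each weight unchanged, so no adjustment is needed at inference time.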

Q: How does the model learn which positions are relevant without explicit supervision?

Attention patterns are learned indirectly through the language modeling objective. The model learns:

  • Predicting the next token requires understanding relationships between previous tokens
  • Through backpropagation, WQ, WK, WV are updated to create attention patterns that help minimize prediction error
  • If attending to the subject helps predict the verb, the model learns Q/K projections that give high scores between verbs and their subjects

This is emergent behavior - we never explicitly tell the model to attend to subjects. It discovers this because it helps the objective.

Q: Why do modern LLMs use no bias in the linear projections (bias=False)?

Several reasons, though empirical results vary:

  • Parameter efficiency: Small reduction in parameters (especially significant at scale)
  • Symmetry: Without bias, the projections are purely linear transformations of the input
  • RMSNorm compatibility: Many modern LLMs use RMSNorm, which rescales activations without re-centering them, making bias terms less useful
  • Empirical results: Papers like LLaMA showed no bias works as well or better

The original transformer used bias=True. The shift to bias=False is a design choice in recent models (LLaMA, Mistral, etc.).

Q: What's the relationship between attention and memory in neural networks?

Self-attention can be viewed as a form of content-addressable memory:

  • The keys (K) form memory "addresses" or "indices"
  • The values (V) form memory "contents"
  • The query (Q) performs a "soft lookup" - matching against all addresses simultaneously
  • Unlike hard memory (which returns one item), attention returns a weighted blend

This perspective connects transformers to memory-augmented neural networks and explains why they're good at tasks requiring retrieval (like question answering).

Q: Is self-attention permutation equivariant? What does that mean?

Yes, self-attention is permutation equivariant: if you shuffle the input tokens, the output is shuffled in the same way.

Mathematically: if π is a permutation, then Attention(π(X)) = π(Attention(X)).

This is a problem! The model has no inherent sense of order. "The cat sat on the mat" and "mat the on sat cat The" would produce the same attention patterns (just reordered).

This is why transformers need positional encodings - to break this symmetry and inject position information. We'll cover this in Lesson 3.
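Equivariance is easy to verify numerically. The sketch below uses projection-free attention (Q = K = V = x), which has the same symmetry as the full version:

```python
import torch
import torch.nn.functional as F

def self_attention(x):
    # Projection-free attention: Q = K = V = x
    scores = x @ x.T / x.size(-1) ** 0.5
    return F.softmax(scores, dim=-1) @ x

torch.manual_seed(0)
x = torch.randn(6, 16)
perm = torch.randperm(6)

out_then_perm = self_attention(x)[perm]   # attend, then shuffle
perm_then_out = self_attention(x[perm])   # shuffle, then attend

# Attention(pi(X)) == pi(Attention(X))
assert torch.allclose(out_then_perm, perm_then_out, atol=1e-6)
print("permutation equivariant")
```

Adding positional encodings to x before attention breaks this symmetry, which is exactly the point of Lesson 3.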

11. Your Questions

As you learn, add your questions here. These will be incorporated into the tutorial for future reference.

No questions yet. As you study this material, ask questions and they'll be documented here!

This section will grow as we discuss the material together.

12. Comprehension Quiz

1. Why do we divide by √d_k in the attention formula?

  • To make the computation faster by using smaller numbers
  • To ensure the output values are in the range [-1, 1]
  • ✓ To prevent dot products from becoming too large, which would cause softmax gradients to vanish
  • To make attention weights sum to 1

Correct! Without scaling, dot products grow with d_k (variance ~ d_k). Large dot products push softmax into saturation regions where outputs are nearly 0 or 1, causing vanishing gradients. Dividing by √d_k normalizes variance back to ~1.

2. What is the computational complexity of self-attention with respect to sequence length n?

  • O(n) - linear
  • ✓ O(n²) - quadratic
  • O(n log n) - linearithmic
  • O(n³) - cubic

Correct! The Q Kᵀ operation creates an n×n attention matrix, requiring O(n²) computations. This quadratic complexity is the main bottleneck for long-context models.

3. In causal (autoregressive) self-attention, position i can attend to:

  • ✓ Positions 0, 1, ..., i (itself and all previous)
  • All positions in the sequence
  • Only position i-1 (the immediately previous token)
  • Positions i, i+1, ..., n (itself and all future)

Correct! Causal masking ensures each position only sees itself and past positions, matching how generation works at inference time (predicting the next token without seeing the future).

4. After softmax, the attention weights for each query position:

  • Can be negative, ranging from -1 to 1
  • Are all equal (uniform distribution)
  • Are integers representing position indices
  • ✓ Are non-negative and sum to 1 (a probability distribution)

Correct! Softmax outputs form a probability distribution: all values ≥ 0 and sum to 1. This means the output is a convex combination of values.

5. Why are Q, K, V separate projections instead of using the same vector?

  • To reduce computational cost
  • To make the model deterministic
  • ✓ To let the model learn different representations for "what to look for", "what to match on", and "what to retrieve"
  • It's just a historical convention with no practical benefit

Correct! Separate projections allow the model to learn that what makes tokens similar for matching (Q·K) might differ from what content should be aggregated (V). This dramatically increases expressivity.

6. Self-attention is permutation equivariant. This means:

  • The output doesn't depend on the input order
  • ✓ Shuffling inputs produces equivalently shuffled outputs - no inherent position awareness
  • The model learns to permute tokens automatically
  • Attention weights are always symmetric

Correct! If you shuffle input tokens, outputs shuffle the same way. Self-attention has no inherent sense of position - this is why transformers need positional encodings to know token order.