1. Historical Context & Motivation
The Problem with Sequential Models
Before transformers, the dominant architectures for sequence processing were Recurrent Neural Networks (RNNs) and their variants (LSTMs, GRUs). These models processed sequences one token at a time, maintaining a hidden state that accumulated information: h_t = f(h_{t-1}, x_t).
This sequential dependency created several critical problems:
- Vanishing/Exploding Gradients: Information had to pass through many time steps. Gradients either vanished (making early tokens hard to learn from) or exploded. LSTMs and GRUs mitigated but didn't solve this.
- Limited Long-Range Dependencies: Even LSTMs struggled to connect information across hundreds of tokens. The hidden state acted as a "bottleneck" - all context had to fit in a fixed-size vector.
- No Parallelization: Because h_t depends on h_{t-1}, you couldn't compute positions in parallel during training. This made RNNs extremely slow to train on modern GPUs.
The Attention Revolution
Attention was first introduced in 2014 for machine translation (Bahdanau et al.). Instead of compressing all source information into a single vector, the decoder could "look back" at all encoder states and decide which ones were relevant.
The 2017 paper "Attention Is All You Need" (Vaswani et al.) took this further: what if we used ONLY attention, with no recurrence at all? This gave birth to the Transformer architecture.
💡 The Fundamental Insight
Self-attention allows every position in a sequence to directly attend to every other position in a single operation. Information doesn't need to "travel" through intermediate states - position 1 can directly influence position 100 with equal ease.
2. The Core Intuition
Attention as Dynamic, Content-Based Routing
Think of self-attention as a sophisticated information routing system. For each position in the sequence, we ask: "Given what I'm trying to compute here, which other positions in this sequence contain relevant information, and how should I combine that information?"
Unlike convolutions (which have fixed receptive fields) or RNNs (which have fixed routing through time), attention routing is:
- Dynamic: The routing pattern changes based on the actual content of the input, not just positions
- Learned: The model learns what constitutes "relevance" during training
- All-to-All: Every position can potentially attend to every other position
The Database Analogy
A useful mental model is to think of attention like a fuzzy database lookup:
Query: "What am I looking for?"
Each position generates a query vector that encodes what information it needs. In a sentence like "The cat sat on the mat", when processing "sat", the query might encode "I need to find the subject of this action."
Key: "What do I contain?"
Each position also generates a key vector that advertises what information it holds. The key for "cat" might encode "I am a noun, likely a subject." The key for "the" might encode "I am a determiner."
Matching: Query · Key = Relevance Score
We compute dot products between the query and all keys. High dot product = high similarity = high relevance. The query for "sat" will have high similarity with the key for "cat" (its subject).
Value: "Here's my actual content"
Each position has a value vector containing its actual information to share. After determining relevance scores, we take a weighted sum of all values, where weights come from the relevance scores.
🔬 Why Separate Keys and Values?
You might wonder: why not just use the same vector for keys and values? The separation allows the model to independently learn:
- What to match on (keys) - the "address" or "index" properties
- What to retrieve (values) - the actual "content" to aggregate
This is similar to how a hash table has separate keys (for lookup) and values (for storage). A word's syntactic role (encoded in key) might differ from its semantic content (encoded in value).
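The fuzzy-lookup analogy can be made concrete in a few lines of PyTorch. All numbers below are made up purely for illustration: one query is matched against three keys, and the output is a weighted blend of the three values.

```python
import torch
import torch.nn.functional as F

q = torch.tensor([1.0, 0.0, 1.0, 0.0])      # query: "what am I looking for?"
K = torch.tensor([[1.0, 0.0, 1.0, 0.0],     # key 0: matches the query exactly
                  [0.0, 1.0, 0.0, 1.0],     # key 1: orthogonal, no match
                  [0.5, 0.0, 0.5, 0.0]])    # key 2: partial match
V = torch.tensor([[10.0, 0.0],              # value 0
                  [0.0, 10.0],              # value 1
                  [5.0, 5.0]])              # value 2

scores = K @ q                      # relevance: dot product with every key
weights = F.softmax(scores, dim=0)  # soft, normalized lookup weights
output = weights @ V                # weighted blend of all values

print(scores)   # tensor([2., 0., 1.])
print(output)   # closest to value 0, whose key matched the query best
```

Unlike a hash table, the lookup is soft: every value contributes, in proportion to how well its key matches the query.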
3. Query, Key, Value: Deep Dive
The Linear Projections
Given an input sequence X ∈ ℝ^{n×d} (n tokens, each with a d-dimensional embedding), we create Q, K, V through learned linear transformations:
Q = X · W_Q ∈ ℝ^{n×d_k}
K = X · W_K ∈ ℝ^{n×d_k}
V = X · W_V ∈ ℝ^{n×d_v}
Dimension Choices
In the original transformer and most modern LLMs:
- d_k = d_v = d / h, where h is the number of attention heads
- For example, with d = 768 and h = 12 heads, each head has d_k = d_v = 64
- Some architectures (like Multi-Query Attention) share K and V across heads to reduce memory
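A quick shape check of a single head at the sizes quoted above (d = 768, h = 12, so d_k = 64); the weight names here are illustrative, not from any particular library:

```python
import torch
import torch.nn as nn

d, h = 768, 12
d_k = d // h                          # 64 dimensions per head

x = torch.randn(2, 10, d)             # (batch, seq_len, d)
W_q = nn.Linear(d, d_k, bias=False)   # one head's query projection
W_k = nn.Linear(d, d_k, bias=False)
W_v = nn.Linear(d, d_k, bias=False)

Q, K, V = W_q(x), W_k(x), W_v(x)
print(Q.shape, K.shape, V.shape)      # each: torch.Size([2, 10, 64])
```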
For the token "cat", for example, the same embedding is projected through three different learned matrices, producing distinct query, key, and value vectors.
⚠️ Common Misconception
Q, K, V are not three different interpretations of the same embedding. They are three completely different projections through three separate learned weight matrices. The same input token produces different Q, K, and V vectors.
What Do These Projections Learn?
During training, the model learns weight matrices that create useful query/key/value spaces:
- WQ learns to project tokens into a "question space" - encoding what information each position needs
- WK learns to project tokens into an "answer space" - encoding what information each position offers, in a way that matches relevant queries
- WV learns to project tokens into a "content space" - encoding the actual information to be aggregated
Research has shown that attention heads often specialize. Some heads learn syntactic relationships (subject-verb), others learn positional patterns (attend to previous token), and others learn semantic relationships (entity coreference).
4. Mathematical Foundations
The Attention Formula
The complete scaled dot-product attention formula is:
Attention(Q, K, V) = softmax(QKᵀ / √d_k) · V
Let's break down each component:
Step 1: Compute Attention Scores (QKᵀ)
The matrix multiplication QKᵀ computes the dot product between every query and every key.
The result is an n×n matrix where entry (i,j) represents how much position i should attend to position j. Higher values indicate higher relevance.
Step 2: Scale by √d_k
We divide by √d_k to counteract the effect of large dot products.
🔬 Why √d_k? A Statistical Argument
Assume q and k are vectors with components drawn independently from a distribution with mean 0 and variance 1. Their dot product is:
q · k = Σ_{i=1}^{d_k} q_i k_i
Since each q_i k_i has variance 1 (the product of two independent unit-variance variables), the sum has variance d_k (a sum of d_k independent terms).
So q · k has standard deviation √d_k. As d_k grows, dot products become larger in magnitude, pushing softmax into saturation regions where gradients nearly vanish.
Dividing by √d_k normalizes the variance back to ~1, keeping gradients healthy regardless of dimension.
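This argument is easy to verify empirically. The sketch below draws random unit-variance vectors and compares the spread of raw versus scaled dot products:

```python
import torch

torch.manual_seed(0)
for d_k in (16, 256, 4096):
    q = torch.randn(2000, d_k)            # 2000 samples, unit-variance components
    k = torch.randn(2000, d_k)
    dots = (q * k).sum(dim=-1)            # 2000 dot products
    raw = dots.std().item()
    scaled = (dots / d_k ** 0.5).std().item()
    print(f"d_k={d_k:>5}  raw std={raw:7.2f}  scaled std={scaled:.2f}")
# raw std grows like sqrt(d_k) (about 4, 16, 64);
# scaled std stays near 1 regardless of d_k
```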
Step 3: Softmax Normalization
We apply softmax along the key dimension (each row): α_{ij} = exp(s_{ij}) / Σ_k exp(s_{ik}), where s_{ij} is the scaled score between query i and key j.
Properties of the resulting attention weights α:
- Non-negative: α_{ij} ≥ 0 for all i, j
- Normalized: Σ_j α_{ij} = 1 for each row i
- Interpretable: α_{ij} can be viewed as "the probability that position i attends to position j"
Step 4: Weighted Sum of Values
Finally, we compute a weighted combination of value vectors: output_i = Σ_j α_{ij} v_j.
In matrix form: Output = α · V, where α is the n×n attention weight matrix and V is n×d_v.
💡 The Beautiful Property
The output for each position is a convex combination of all value vectors (weights sum to 1, all non-negative). This means the output lies within the "convex hull" of the value vectors - attention can select, interpolate, or mix information, but cannot extrapolate beyond what's in the values.
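A quick numerical check of these properties on random tensors:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, d_k, d_v = 6, 8, 4
Q, K = torch.randn(n, d_k), torch.randn(n, d_k)
V = torch.randn(n, d_v)

weights = F.softmax(Q @ K.T / d_k ** 0.5, dim=-1)
output = weights @ V

assert (weights >= 0).all()                              # non-negative
assert torch.allclose(weights.sum(-1), torch.ones(n))    # rows sum to 1
# Convex combination: each output coordinate stays between the
# min and max of that coordinate across the value vectors.
assert (output <= V.max(dim=0).values + 1e-6).all()
assert (output >= V.min(dim=0).values - 1e-6).all()
print("all checks passed")
```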
5. Interactive Visualization
Consider the attention pattern for the sentence "The cat sat on the mat".
Reading the matrix: Row = Query (the position doing the looking), Column = Key (the position being looked at).
This shows causal attention - each position can only attend to itself and previous positions (future positions are masked out).
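Since the interactive widget is not reproduced here, the sketch below prints a causal attention matrix for the example sentence. The embeddings are random, so only the masked triangular structure is meaningful - not the particular weights a trained model would produce:

```python
import torch
import torch.nn.functional as F

tokens = ["The", "cat", "sat", "on", "the", "mat"]
n, d = len(tokens), 8

torch.manual_seed(0)
x = torch.randn(n, d)                                # random stand-in embeddings
scores = x @ x.T / d ** 0.5
causal = torch.tril(torch.ones(n, n)).bool()         # lower-triangular mask
scores = scores.masked_fill(~causal, float("-inf"))  # hide future positions
weights = F.softmax(scores, dim=-1)

for tok, row in zip(tokens, weights):
    cells = " ".join(f"{float(w):.2f}" if w > 0 else "   -" for w in row)
    print(f"{tok:>4} | {cells}")
```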
Understanding the Attention Pattern
The visualization shows a simulated attention pattern that a trained language model might produce. Notice:
- "sat" strongly attends to "cat" - the verb connects to its subject
- "the" (second) attends to "The" (first) - learning that these are the same word in different positions
- Each position attends to itself - often important for preserving its own information
- The lower triangle pattern - this is the causal mask, preventing positions from seeing the future
6. PyTorch Implementation
Let's implement self-attention from scratch. We'll build up from a minimal version to a more complete implementation.
Minimal Implementation
import torch
import torch.nn.functional as F
import math

def attention(Q, K, V, mask=None):
    """
    Minimal scaled dot-product attention.

    Args:
        Q: Queries (batch, seq_len, d_k)
        K: Keys (batch, seq_len, d_k)
        V: Values (batch, seq_len, d_v)
        mask: Optional mask (batch, seq_len, seq_len) or (1, seq_len, seq_len)

    Returns:
        output: Attended values (batch, seq_len, d_v)
        weights: Attention weights (batch, seq_len, seq_len)
    """
    d_k = Q.size(-1)

    # Compute attention scores
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

    # Apply mask (if provided)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)

    # Softmax to get attention weights
    weights = F.softmax(scores, dim=-1)

    # Weighted sum of values
    output = torch.matmul(weights, V)

    return output, weights
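A quick sanity check of attention() on random tensors (the function body is repeated so this snippet runs on its own):

```python
import math
import torch
import torch.nn.functional as F

def attention(Q, K, V, mask=None):
    # Same minimal function as above, reproduced for self-containment.
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, V), weights

x = torch.randn(2, 5, 16)           # stand-in for already-projected Q = K = V
output, weights = attention(x, x, x)
print(output.shape, weights.shape)  # torch.Size([2, 5, 16]) torch.Size([2, 5, 5])
print(weights.sum(-1))              # every row sums to ~1
```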
Complete Self-Attention Module
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple

class SelfAttention(nn.Module):
    """
    Scaled Dot-Product Self-Attention.

    This is single-head attention. For multi-head, see the next lesson.

    Args:
        d_model: Input embedding dimension
        d_k: Key/Query dimension (defaults to d_model)
        d_v: Value dimension (defaults to d_k)
        dropout: Dropout probability on attention weights
    """
    def __init__(
        self,
        d_model: int,
        d_k: Optional[int] = None,
        d_v: Optional[int] = None,
        dropout: float = 0.0
    ):
        super().__init__()
        self.d_k = d_k if d_k is not None else d_model
        self.d_v = d_v if d_v is not None else self.d_k
        self.scale = math.sqrt(self.d_k)

        # Linear projections
        # Note: bias=False is common in modern LLMs (LLaMA, etc.)
        self.W_q = nn.Linear(d_model, self.d_k, bias=False)
        self.W_k = nn.Linear(d_model, self.d_k, bias=False)
        self.W_v = nn.Linear(d_model, self.d_v, bias=False)

        # Output projection (projects back to d_model if d_v != d_model)
        self.W_o = nn.Linear(self.d_v, d_model, bias=False)

        # Dropout on attention weights (regularization)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        return_weights: bool = False
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Forward pass.

        Args:
            x: Input tensor (batch, seq_len, d_model)
            mask: Attention mask (batch, 1, seq_len) or (batch, seq_len, seq_len)
                1 = attend, 0 = mask out
            return_weights: Whether to return attention weights

        Returns:
            output: (batch, seq_len, d_model)
            weights: (batch, seq_len, seq_len) if return_weights=True
        """
        batch_size, seq_len, _ = x.shape

        # Project to Q, K, V
        Q = self.W_q(x)  # (batch, seq_len, d_k)
        K = self.W_k(x)  # (batch, seq_len, d_k)
        V = self.W_v(x)  # (batch, seq_len, d_v)

        # Compute attention scores
        # (batch, seq_len, d_k) @ (batch, d_k, seq_len) -> (batch, seq_len, seq_len)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale

        # Apply mask
        if mask is not None:
            # Use -inf so softmax gives 0 probability
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # Softmax over keys (last dimension)
        weights = F.softmax(scores, dim=-1)

        # Apply dropout to attention weights
        weights = self.dropout(weights)

        # Weighted sum of values
        # (batch, seq_len, seq_len) @ (batch, seq_len, d_v) -> (batch, seq_len, d_v)
        attended = torch.matmul(weights, V)

        # Project back to d_model
        output = self.W_o(attended)

        if return_weights:
            return output, weights
        return output, None
# ===== Example Usage =====
if __name__ == "__main__":
    # Hyperparameters
    batch_size = 2
    seq_len = 10
    d_model = 512

    # Create random input (simulating token embeddings)
    x = torch.randn(batch_size, seq_len, d_model)

    # Create attention layer
    attn = SelfAttention(d_model=d_model, dropout=0.1)
    attn.eval()  # disable dropout so the weight rows sum to exactly 1

    # Forward pass (no mask = full attention)
    output, weights = attn(x, return_weights=True)

    print(f"Input shape: {x.shape}")          # (2, 10, 512)
    print(f"Output shape: {output.shape}")    # (2, 10, 512)
    print(f"Weights shape: {weights.shape}")  # (2, 10, 10)

    # Verify attention weights sum to 1
    print(f"Row sums: {weights[0].sum(dim=-1)}")  # All ~1.0
💡 Code Walkthrough
- Projections: four linear layers - three for Q, K, V, plus an output projection (W_o) that maps back to d_model.
- Dropout on attention weights is a form of regularization - randomly zeroing out some attention connections during training.
- The batched matrix multiplication of Q with Kᵀ is the key operation - it computes all pairwise attention scores at once.
- Masking uses -inf so softmax produces exactly 0 (e^{-∞} = 0), completely blocking information flow.
7. Causal Masking for Autoregressive Models
Why Causal Masking?
In autoregressive language models (GPT, LLaMA, Claude), we generate text left-to-right, predicting each token based only on previous tokens. During training, we need to enforce this constraint so the model can't "cheat" by looking at future tokens.
Causal mask (also called "look-ahead mask"): Position i can only attend to positions 0, 1, ..., i.
def create_causal_mask(seq_len: int, device: torch.device = None) -> torch.Tensor:
    """
    Creates a causal (lower triangular) attention mask.

    Args:
        seq_len: Sequence length
        device: Device to create tensor on

    Returns:
        mask: (1, seq_len, seq_len) tensor
            1 = can attend, 0 = cannot attend (will become -inf)
    """
    # torch.tril creates lower triangular matrix
    mask = torch.tril(torch.ones(seq_len, seq_len, device=device))
    # Add batch dimension for broadcasting
    return mask.unsqueeze(0)  # (1, seq_len, seq_len)

# Example: seq_len = 5
mask = create_causal_mask(5)
print(mask[0])
# tensor([[1., 0., 0., 0., 0.],   <- token 0 sees only itself
#         [1., 1., 0., 0., 0.],   <- token 1 sees tokens 0, 1
#         [1., 1., 1., 0., 0.],   <- token 2 sees tokens 0, 1, 2
#         [1., 1., 1., 1., 0.],   <- token 3 sees tokens 0, 1, 2, 3
#         [1., 1., 1., 1., 1.]])  <- token 4 sees all tokens
# Combining causal mask with padding mask
def create_combined_mask(
    seq_len: int,
    padding_mask: torch.Tensor = None  # (batch, seq_len), 1=real token, 0=padding
) -> torch.Tensor:
    """
    Creates a mask that combines the causal constraint with padding.

    Returns:
        mask: (batch, seq_len, seq_len), or (1, seq_len, seq_len) if no padding mask
    """
    device = padding_mask.device if padding_mask is not None else None

    # Start with causal mask
    causal = create_causal_mask(seq_len, device=device)

    if padding_mask is None:
        return causal

    # Expand padding mask: don't attend to padding tokens
    # (batch, seq_len) -> (batch, 1, seq_len)
    padding = padding_mask.unsqueeze(1)

    # Combine: must satisfy BOTH causal AND not-padding
    return causal * padding
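The combination logic is easy to see on a tiny example. Here the second sequence in a batch of two ends with one padding token, so its last key column is zeroed in every row:

```python
import torch

seq_len = 4
causal = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0)  # (1, 4, 4)

# Batch of 2: first sequence is full length, second ends in padding.
padding_mask = torch.tensor([[1., 1., 1., 1.],
                             [1., 1., 1., 0.]])
combined = causal * padding_mask.unsqueeze(1)   # broadcasts to (2, 4, 4)

print(combined[1])
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 1., 0.],
#         [1., 1., 1., 0.]])   <- last column zeroed: no one attends to padding
```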
Bidirectional vs Causal Attention
| Aspect | Bidirectional (BERT-style) | Causal (GPT-style) |
|---|---|---|
| Mask | No mask (or only padding mask) | Lower triangular mask |
| Attention | Each position sees all positions | Each position sees only past + self |
| Use case | Understanding (classification, embeddings) | Generation (language models) |
| Training | Masked language modeling | Next token prediction |
8. Computational Complexity Analysis
Time Complexity
For a sequence of length n with dimension d:
| Operation | Complexity | Notes |
|---|---|---|
| Q, K, V projections | O(n · d²) | Linear in sequence length |
| QKᵀ (attention scores) | O(n² · d) | Quadratic in sequence length! |
| Softmax | O(n²) | Applied to n×n matrix |
| Attention · V | O(n² · d) | Quadratic in sequence length! |
Total: O(n² · d) - The quadratic dependence on n is the main bottleneck for long sequences.
Memory Complexity
The attention weight matrix α ∈ ℝ^{n×n} must be stored (at least temporarily):
- For n = 2048 tokens: 2048² × 4 bytes = 16 MB per attention layer per batch element
- For n = 100,000 tokens: 100,000² × 4 bytes = 40 GB per layer! (Infeasible)
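These figures follow directly from counting the n² entries at 4 bytes (float32) each:

```python
def attn_matrix_bytes(n: int, dtype_bytes: int = 4) -> int:
    # One n x n score matrix, per layer / head / batch element.
    return n * n * dtype_bytes

for n in (2048, 100_000):
    print(f"n={n:>7}: {attn_matrix_bytes(n) / 1e9:.3f} GB")
# n=   2048: 0.017 GB   (~16 MB)
# n= 100000: 40.000 GB
```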
⚠️ The Long-Context Challenge
This quadratic complexity is why early transformers were limited to ~512-2048 tokens. Modern techniques that enable 100k+ token contexts include:
- Flash Attention: Computes attention without materializing the full n×n matrix
- Sparse Attention: Only compute attention for a subset of positions
- Linear Attention: Approximations that achieve O(n) complexity
We'll cover these in detail in Module 6 (Inference Optimization).
Comparison with RNNs
| Aspect | Self-Attention | RNN |
|---|---|---|
| Time per layer | O(n² · d) | O(n · d²) |
| Sequential operations | O(1) - fully parallel | O(n) - inherently sequential |
| Max path length | O(1) - any position to any position | O(n) - must pass through all states |
For typical sequence lengths (hundreds to low thousands), self-attention's parallelizability makes it much faster to train on GPUs despite higher theoretical complexity.
9. Attention Variants (Preview)
The basic self-attention we've covered is the foundation, but modern LLMs use several important variants:
Multi-Head Attention
Instead of one attention operation, run h parallel attention "heads" with smaller dimensions (d_k = d/h), then concatenate results. This allows the model to jointly attend to information from different representation subspaces. → Next lesson
Grouped Query Attention (GQA)
Used in Llama 2, Mistral. Share K and V projections across groups of heads to reduce memory (KV cache). We'll cover this in Module 6.
Multi-Query Attention (MQA)
Extreme version of GQA: all heads share the same K and V. Faster inference but potentially lower quality.
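A shape-level sketch of the GQA/MQA idea (illustrative only - real implementations integrate this with the multi-head module covered later). A small number of K/V heads is expanded so that groups of query heads share them; with n_kv_heads = 1 this reduces to MQA:

```python
import torch

batch, seq, head_dim = 2, 10, 64
n_q_heads, n_kv_heads = 8, 2            # every 4 query heads share one K/V head

Q = torch.randn(batch, n_q_heads, seq, head_dim)
K = torch.randn(batch, n_kv_heads, seq, head_dim)   # far fewer K/V activations,
V = torch.randn(batch, n_kv_heads, seq, head_dim)   # so a 4x smaller KV cache

# Expand K/V so shapes line up with the query heads.
group = n_q_heads // n_kv_heads
K = K.repeat_interleave(group, dim=1)   # (2, 8, 10, 64)
V = V.repeat_interleave(group, dim=1)

scores = Q @ K.transpose(-2, -1) / head_dim ** 0.5
out = torch.softmax(scores, dim=-1) @ V
print(out.shape)   # torch.Size([2, 8, 10, 64])
```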
Multi-head Latent Attention (MLA)
Used in DeepSeek-V2/V3. Uses low-rank projection to compress KV cache. Covered in Module 7.
Flash Attention
Not a different attention formula, but a memory-efficient implementation that avoids materializing the full attention matrix. Critical for long contexts. Covered in Module 6.
10. Important Questions & Answers
These are fundamental questions that deepen understanding of self-attention:
Q: Why not use the input X directly as Q, K, and V?
If Q = K = V = X (the input), then the attention scores would simply be XXᵀ, which computes raw similarity between token embeddings. This has several problems:
- No learnable parameters: The model couldn't learn what constitutes "relevance" for the task
- Same space for matching and content: What makes tokens similar (for matching) might differ from what information should be aggregated (content)
- Limited expressivity: Can only express similarities already present in the embedding space
The separate projections let the model learn task-specific notions of "what to look for" (Q), "what to advertise" (K), and "what to share" (V).
Q: Can attention weights be negative, or greater than 1?
No and no. After softmax, all attention weights are in the range [0, 1] and sum to 1 along each row. This means:
- The output is always a convex combination of value vectors
- Attention can select (put all weight on one position), average (spread weight evenly), or anything in between
- Attention cannot extrapolate - the output lies within the convex hull of the values
Some research has explored allowing negative attention (e.g., "differential attention"), but standard transformers use non-negative weights.
Q: What happens with very long sequences?
Several issues emerge:
- Dilution: Softmax normalizes over all positions, so with many positions individual attention weights can become vanishingly small (related is the "attention sink" phenomenon, where models dump excess weight on early tokens)
- Computational cost: O(n²) becomes prohibitive. 100k tokens = 10 billion attention score computations per layer
- Memory: The n×n attention matrix can't fit in GPU memory for very long sequences
Modern solutions include positional extrapolation (RoPE scaling), sliding window attention, and Flash Attention (memory-efficient implementation).
Q: Why apply dropout to attention weights?
Attention dropout is a regularization technique that helps with:
- Preventing over-reliance: Without dropout, the model might learn to always attend to specific positions (like the first token), reducing generalization
- Robustness: Forces the model to have multiple "paths" for information flow
- Implicit ensemble: Like all dropout, it trains an ensemble of sub-networks
Note: Attention dropout is typically only applied during training. At inference, all attention weights are used.
Q: How does the model learn what to attend to?
Attention patterns are learned indirectly through the language modeling objective:
- Predicting the next token requires understanding relationships between previous tokens
- Through backpropagation, WQ, WK, WV are updated to create attention patterns that help minimize prediction error
- If attending to the subject helps predict the verb, the model learns Q/K projections that give high scores between verbs and their subjects
This is emergent behavior - we never explicitly tell the model to attend to subjects. It discovers this because it helps the objective.
Q: Why do modern LLMs use bias=False in the Q/K/V projections?
Several reasons, though empirical results vary:
- Parameter efficiency: Small reduction in parameters (especially significant at scale)
- Symmetry: Without bias, the projections are purely linear transformations of the input
- RMSNorm compatibility: Many modern LLMs use RMSNorm, which (unlike LayerNorm) does not re-center activations, making bias terms less useful alongside it
- Empirical results: Papers like LLaMA showed no bias works as well or better
The original transformer used bias=True. The shift to bias=False is a design choice in recent models (LLaMA, Mistral, etc.).
Q: How does attention relate to memory?
Self-attention can be viewed as a form of content-addressable memory:
- The keys (K) form memory "addresses" or "indices"
- The values (V) form memory "contents"
- The query (Q) performs a "soft lookup" - matching against all addresses simultaneously
- Unlike hard memory (which returns one item), attention returns a weighted blend
This perspective connects transformers to memory-augmented neural networks and explains why they're good at tasks requiring retrieval (like question answering).
Q: Is self-attention permutation equivariant?
Yes, self-attention is permutation equivariant: if you shuffle the input tokens, the output is shuffled in the same way.
Mathematically: if π is a permutation, then Attention(π(X)) = π(Attention(X)).
This is a problem! The model has no inherent sense of order. "The cat sat on the mat" and "mat the on sat cat The" would produce the same attention patterns (just reordered).
This is why transformers need positional encodings - to break this symmetry and inject position information. We'll cover this in Lesson 3.
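The equivariance claim can be verified numerically. This sketch uses Q = K = V = X to keep things minimal; adding learned projections does not change the conclusion, because they act independently on each position:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, d = 6, 16
X = torch.randn(n, d)

def self_attn(X):
    # Plain attention with Q = K = V = X (no projections).
    return F.softmax(X @ X.T / d ** 0.5, dim=-1) @ X

perm = torch.randperm(n)
permute_then_attend = self_attn(X[perm])
attend_then_permute = self_attn(X)[perm]
print(torch.allclose(permute_then_attend, attend_then_permute, atol=1e-6))  # True
```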
11. Your Questions
As you learn, add your questions here. These will be incorporated into the tutorial for future reference.
No questions yet. As you study this material, ask questions and they'll be documented here!
This section will grow as we discuss the material together.