Vision Transformers

I figure the best way to understand Vision Transformers is to trace the concrete tensor shapes through a single forward pass, from raw image to prediction.

Setup and Notation

Symbol Value Meaning
\(H, W, C\) \(224, 224, 3\) Image height, width, channels
\(P\) \(16\) Patch size
\(N\) \(196\) Number of patches (\(14 \times 14\))
\(D\) \(768\) Embedding dimension
\(d_k\) \(64\) Attention head dimension
\(L\) \(1\) Number of encoder blocks (a single block is traced here)
\(K\) \(10\) Number of output classes



Step 1 — Patch Extraction

The image \(I \in \mathbb{R}^{224 \times 224 \times 3}\) is divided into a grid of non-overlapping \(16 \times 16\) patches. Each patch is flattened into a vector:
\[\mathbf{x}_i \in \mathbb{R}^{P^2 \cdot C} = \mathbb{R}^{768} \qquad i = 1, \dots, 196\]
Stacked into a matrix:
\[X_{\text{patches}} \in \mathbb{R}^{196 \times 768}\]
Shape: \([196,\ 768]\)
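
A minimal NumPy sketch of the patchify step (the image is a random stand-in, and the reshape/transpose pattern is just one common way to reach the \([196,\ 768]\) layout):

```python
import numpy as np

H = W = 224          # image height/width
C, P = 3, 16         # channels, patch size

image = np.random.rand(H, W, C)                     # stand-in for a real image

# [224, 224, 3] -> [14, 16, 14, 16, 3] -> [14, 14, 16, 16, 3] -> [196, 768]
grid = image.reshape(H // P, P, W // P, P, C)
grid = grid.transpose(0, 2, 1, 3, 4)
x_patches = grid.reshape((H // P) * (W // P), P * P * C)
print(x_patches.shape)                              # (196, 768)
```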

Step 2 — Linear Patch Embedding

Raw pixel vectors live in a space whose axes are individual pixel intensities. This is a poor space for reasoning — there is no semantic structure along those axes. The embedding projects each patch into a learned space \(\mathbb{R}^D\) where geometrically meaningful relationships can emerge.
\[Z_{\text{patches}} = X_{\text{patches}}\, W_E + \mathbf{b}_E \qquad W_E \in \mathbb{R}^{768 \times 768},\ \mathbf{b}_E \in \mathbb{R}^{768}\]
The analogy for a physicist: choosing the right basis for a Hamiltonian. Working in the pixel basis is like staying in the position basis when the energy eigenbasis reveals all the structure. \(W_E\) is learned — the training process finds the projection.
Shape: \([196,\ 768]\)
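
As a sketch, the embedding is a single matrix multiply; \(W_E\) and \(\mathbf{b}_E\) are learned in practice, so random stand-ins are used here:

```python
import numpy as np

x_patches = np.random.rand(196, 768)        # output of Step 1
W_E = np.random.randn(768, 768) * 0.02      # learned projection (random stand-in)
b_E = np.zeros(768)                         # learned bias (stand-in)

z_patches = x_patches @ W_E + b_E           # project into the embedding space
print(z_patches.shape)                      # (196, 768)
```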

Step 3 — CLS Token

A single learnable vector \(\mathbf{z}_{\text{cls}} \in \mathbb{R}^{768}\) is prepended to the sequence. It carries no pixel information — it is a trained parameter, initialized randomly and updated by backpropagation. By the end of the encoder, it will have aggregated global context from all other tokens through attention. For classification, it is the only token used.
\[Z = \left[\mathbf{z}_{\text{cls}}\ ;\ Z_{\text{patches}}\right] \in \mathbb{R}^{197 \times 768}\]
Shape: \([197,\ 768]\)
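
A sketch of the prepend, with a random stand-in for the learned CLS vector:

```python
import numpy as np

z_patches = np.random.rand(196, 768)        # output of Step 2
z_cls = np.random.randn(1, 768) * 0.02      # learned CLS token (random stand-in)

Z = np.concatenate([z_cls, z_patches], axis=0)   # prepend as row 0
print(Z.shape)                              # (197, 768)
```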

Step 4 — Positional Encoding

The attention operation (defined in Step 6) is permutation-invariant: shuffling the rows of \(Z\) produces the same result. Spatial position must therefore be injected explicitly. A learned matrix \(P_{\text{pos}} \in \mathbb{R}^{197 \times 768}\) is added element-wise:
\[Z \leftarrow Z + P_{\text{pos}}\]
One position vector per token, each a learned \(D\)-dimensional parameter. After this step every token carries both what it looks like and where it is.
Shape: \([197,\ 768]\) — unchanged
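
A sketch of the addition, with a random stand-in for the learned \(P_{\text{pos}}\) table:

```python
import numpy as np

Z = np.random.rand(197, 768)                # output of Step 3
P_pos = np.random.randn(197, 768) * 0.02    # learned positional table (stand-in)

Z = Z + P_pos                               # one position vector added per token
print(Z.shape)                              # (197, 768), unchanged
```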

Step 5 — Layer Normalization

Before attention, each token vector is independently normalized across its \(D\) feature dimensions:
\[\text{LN}(\mathbf{x}) = \gamma \odot \frac{\mathbf{x} - \mu}{\sigma + \epsilon} + \beta\]
where \(\mu\) and \(\sigma\) are the scalar mean and standard deviation of \(\mathbf{x} \in \mathbb{R}^D\), and \(\gamma, \beta \in \mathbb{R}^D\) are learned scale and shift parameters.
Applied row-wise: each of the 197 vectors is normalized independently. Without this, as weights update during training, the distribution of activations shifts unpredictably, causing dot products \(\mathbf{q}^\top \mathbf{k}\) to grow in an uncontrolled way. The learned \(\gamma, \beta\) ensure the normalization does not discard representational capacity — the network can undo the normalization if needed.
\[\tilde{Z} = \text{LN}(Z)\]
Shape: \([197,\ 768]\) — unchanged
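
A from-scratch sketch of the row-wise LayerNorm above (the \(\epsilon\) value is an assumption; frameworks expose it as a parameter):

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-6):
    # mean/std taken across the 768 features of each token, independently per row
    mu = x.mean(axis=-1, keepdims=True)
    sigma = x.std(axis=-1, keepdims=True)
    return gamma * (x - mu) / (sigma + eps) + beta

Z = np.random.rand(197, 768)
gamma, beta = np.ones(768), np.zeros(768)   # learned scale/shift (stand-ins)
Z_tilde = layer_norm(Z, gamma, beta)
print(Z_tilde.shape)                        # (197, 768)
```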

Step 6 — Self-Attention


Motivation

A convolution at layer \(l\) sees only a \(k \times k\) neighborhood. Relating a patch in the top-left corner to one in the bottom-right requires \(O(\text{image\_size}/k)\) stacked layers — long-range dependencies are expensive and lossy.
Attention asks a different question: given a set of \(N\) vectors, how should each vector update itself by aggregating information from all others, weighted by relevance? For defect detection, the relevance of a patch often depends on context far away — a scratch is only meaningful given knowledge of the surrounding material.

6a — Projecting into Q, K, V

Three learned matrices \(W_Q, W_K, W_V \in \mathbb{R}^{768 \times 64}\) project each token into three separate 64-dimensional subspaces:
\[Q = \tilde{Z}\, W_Q, \qquad K = \tilde{Z}\, W_K, \qquad V = \tilde{Z}\, W_V\]
\[\mathbb{R}^{197 \times 768} \times \mathbb{R}^{768 \times 64} \rightarrow \mathbb{R}^{197 \times 64} \quad \text{each}\]
All three come from the same \(\tilde{Z}\). This is self-attention — each token simultaneously plays three roles. The three projection matrices are learned independently, each specializing for a different role:

  • \(W_Q\) — what am I looking for?
  • \(W_K\) — what do I offer as a match?
  • \(W_V\) — what do I contribute if selected?
Why not just use the raw embedding \(\tilde{Z}\) directly?
You could. \(\text{softmax}(\tilde{Z}\tilde{Z}^\top / \sqrt{768}) \cdot \tilde{Z}\) is a valid operation, and if the embedding is good, similar patches will have high dot products and attend to each other. This works. The projections are a generalization that removes two specific constraints:
Constraint 1 — Symmetry. With a single shared matrix \(W\), the score matrix \(\tilde{Z} W W^\top \tilde{Z}^\top\) is always symmetric: token \(i\) scores token \(j\) exactly as highly as \(j\) scores \(i\). With separate \(W_Q \neq W_K\), attention is asymmetric. In a defect scenario: a background patch may need to look at a defect patch to update its context representation, but the defect patch does not need to reciprocate.
Constraint 2 — Coupling of comparison and content. With a single \(W\), the subspace used to compare tokens (to decide who attends to whom) is the same subspace used to determine what information flows when a token is attended to. With separate \(W_Q, W_K\) for comparison and \(W_V\) for content, the model can learn to route attention by one criterion (is this a defect?) while passing forward a different set of features (the defect's edge orientation, contrast against the substrate, etc.).
The raw embedding space is optimized for carrying information, not for computing similarity. The projection into \(\mathbb{R}^{64}\) carves out a dedicated comparison subspace. This is the same logic as going from pixel space to embedding space — each projection finds a basis suited to its purpose.
Shape of Q, K, V: \([197,\ 64]\)
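
A sketch of the three projections, with random stand-ins for \(W_Q, W_K, W_V\):

```python
import numpy as np

Z_tilde = np.random.rand(197, 768)          # LayerNorm output from Step 5
W_Q = np.random.randn(768, 64) * 0.02       # learned projections (random stand-ins)
W_K = np.random.randn(768, 64) * 0.02
W_V = np.random.randn(768, 64) * 0.02

Q, K, V = Z_tilde @ W_Q, Z_tilde @ W_K, Z_tilde @ W_V
print(Q.shape, K.shape, V.shape)            # (197, 64) each
```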

6b — Attention Scores

\[A = \frac{Q K^\top}{\sqrt{d_k}} = \frac{Q K^\top}{8}\]
\[\mathbb{R}^{197 \times 64} \times \mathbb{R}^{64 \times 197} \rightarrow \mathbb{R}^{197 \times 197}\]
Entry \(A_{ij}\) is the raw compatibility between token \(i\)'s query and token \(j\)'s key.
The \(\sqrt{d_k}\) scaling is variance stabilization. Each dot product is a sum of \(d_k = 64\) terms each with \(\sim O(1)\) variance, so without scaling the variance grows as \(d_k\). Large magnitudes push softmax toward a one-hot distribution — near-zero gradients, unstable training. Dividing by \(\sqrt{d_k}\) normalizes variance back to \(O(1)\).
Shape: \([197,\ 197]\)
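
A sketch of the scaled score computation; the variance check at the end illustrates why the \(\sqrt{d_k}\) factor matters:

```python
import numpy as np

d_k = 64
Q = np.random.randn(197, d_k)               # unit-variance entries for illustration
K = np.random.randn(197, d_k)

A = Q @ K.T / np.sqrt(d_k)                  # [197, 197] scaled scores
print(A.shape, round(A.var(), 2))           # variance stays O(1) after scaling
```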

6c — Softmax

\[\hat{A}_{ij} = \frac{e^{A_{ij}}}{\sum_{k=1}^{197} e^{A_{ik}}}\]
Row-wise: each row sums to 1. Row \(i\) is a probability distribution over all 197 tokens — how much token \(i\) attends to every other token. This matrix is directly visualizable as an attention map.
Shape: \([197,\ 197]\)
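
A sketch of the row-wise softmax; the max-subtraction is a standard numerical-stability trick, not part of the math above:

```python
import numpy as np

def softmax_rows(A):
    A = A - A.max(axis=-1, keepdims=True)   # numerical stability only
    e = np.exp(A)
    return e / e.sum(axis=-1, keepdims=True)

A = np.random.randn(197, 197)
A_hat = softmax_rows(A)
print(A_hat.shape, A_hat.sum(axis=-1)[:3])  # each row sums to 1
```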

6d — Weighted Aggregation

\[\text{SA}(\tilde{Z}) = \hat{A}\, V\]
\[\mathbb{R}^{197 \times 197} \times \mathbb{R}^{197 \times 64} \rightarrow \mathbb{R}^{197 \times 64}\]
Row \(i\) of the output is a weighted sum of all value vectors. Each token has now aggregated information from the entire sequence, weighted by learned relevance.
Shape: \([197,\ 64]\)
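
A sketch of the aggregation, using normalized random weights as a stand-in for \(\hat{A}\):

```python
import numpy as np

A_hat = np.random.rand(197, 197)
A_hat /= A_hat.sum(axis=-1, keepdims=True)  # rows sum to 1, like real attention weights
V = np.random.randn(197, 64)

sa_out = A_hat @ V                          # row i = weighted sum of all value vectors
print(sa_out.shape)                         # (197, 64)
```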

6e — Output Projection

Project back to dimension \(D\) with \(W_O \in \mathbb{R}^{64 \times 768}\):
\[\text{MSA}(\tilde{Z}) = \text{SA}(\tilde{Z})\, W_O\]
(With the single head traced here, \(\text{MSA}\) is just \(\text{SA}\) followed by \(W_O\); a full multi-head block would concatenate several heads before this projection.)
Shape: \([197,\ 768]\)
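
A sketch of the output projection back to \(D = 768\):

```python
import numpy as np

sa_out = np.random.randn(197, 64)           # output of 6d
W_O = np.random.randn(64, 768) * 0.02       # learned output projection (stand-in)

msa_out = sa_out @ W_O
print(msa_out.shape)                        # (197, 768)
```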

Step 7 — Residual Connection (post-attention)

\[Z \leftarrow Z + \text{MSA}(\text{LN}(Z))\]
\[\mathbb{R}^{197 \times 768} + \mathbb{R}^{197 \times 768} \rightarrow \mathbb{R}^{197 \times 768}\]
The residual adds the original \(Z\) back to the attention output. Two purposes: (1) a gradient highway — gradients flow directly through the addition without passing through the attention computation; (2) the attention block only needs to learn a correction to the identity, which is an easier optimization target than learning the full transformation from scratch.
Shape: \([197,\ 768]\)

Step 8 — Layer Normalization + MLP Block

A second LayerNorm followed by a position-wise two-layer MLP applied independently to each token:
\[\text{MLP}(\mathbf{x}) = \text{GELU}(\mathbf{x}\, W_1 + \mathbf{b}_1)\, W_2 + \mathbf{b}_2\]
\[W_1 \in \mathbb{R}^{768 \times 3072}, \quad W_2 \in \mathbb{R}^{3072 \times 768}\]
Per token: \(\mathbb{R}^{768} \rightarrow \mathbb{R}^{3072} \rightarrow \mathbb{R}^{768}\). Applied to all 197 tokens simultaneously. Followed by a second residual:
\[Z \leftarrow Z + \text{MLP}(\text{LN}(Z))\]
Attention mixes information across tokens. The MLP processes each token independently after that mixing. The two operations are complementary — global routing followed by local refinement.
Shape: \([197,\ 768]\) — unchanged throughout
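
Putting Steps 5–8 together: a minimal single-head, pre-norm encoder block in NumPy. The \(\gamma, \beta\) of LayerNorm are omitted, GELU uses the common tanh approximation, and all weights are random stand-ins:

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    # gamma/beta omitted for brevity (i.e. gamma = 1, beta = 0)
    mu = x.mean(axis=-1, keepdims=True)
    sigma = x.std(axis=-1, keepdims=True)
    return (x - mu) / (sigma + eps)

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

def softmax_rows(A):
    A = A - A.max(axis=-1, keepdims=True)
    e = np.exp(A)
    return e / e.sum(axis=-1, keepdims=True)

def encoder_block(Z, W_Q, W_K, W_V, W_O, W_1, b_1, W_2, b_2):
    # Steps 5-7: pre-norm single-head attention + residual
    Zt = layer_norm(Z)
    Q, K, V = Zt @ W_Q, Zt @ W_K, Zt @ W_V
    A_hat = softmax_rows(Q @ K.T / np.sqrt(Q.shape[-1]))
    Z = Z + (A_hat @ V) @ W_O
    # Step 8: pre-norm MLP + residual
    Zt = layer_norm(Z)
    Z = Z + (gelu(Zt @ W_1 + b_1) @ W_2 + b_2)
    return Z

rng = np.random.default_rng(0)
D, d_k, D_mlp = 768, 64, 3072
w = lambda *s: rng.standard_normal(s) * 0.02        # random stand-in weights
Z = rng.standard_normal((197, D))
out = encoder_block(Z, w(D, d_k), w(D, d_k), w(D, d_k), w(d_k, D),
                    w(D, D_mlp), np.zeros(D_mlp), w(D_mlp, D), np.zeros(D))
print(out.shape)                                    # (197, 768)
```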

Step 9 — Classification Head

Extract the CLS token (row 0) and apply a linear classifier:
\[\hat{\mathbf{y}} = Z[0,\ :]\, W_{\text{cls}} \qquad W_{\text{cls}} \in \mathbb{R}^{768 \times K}\]
Shape: \([768] \rightarrow [K]\)
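
A sketch of the classification head, with a random stand-in for \(W_{\text{cls}}\):

```python
import numpy as np

K = 10
Z = np.random.randn(197, 768)               # encoder output
W_cls = np.random.randn(768, K) * 0.02      # learned classifier (random stand-in)

logits = Z[0, :] @ W_cls                    # only the CLS token (row 0) is used
print(logits.shape)                         # (10,)
```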

Step 10 — Segmentation Pivot

Classification uses the CLS token — one label per image, spatial information discarded. Segmentation requires one label per pixel — spatial information must be preserved.
Discard the CLS token. Keep all patch tokens:
\[Z_{\text{patches}} = Z[1:,\ :] \in \mathbb{R}^{196 \times 768}\]
Reshape to a spatial grid:
\[\mathbb{R}^{196 \times 768} \rightarrow \mathbb{R}^{14 \times 14 \times 768}\]
This is a spatial feature map — structurally equivalent to a CNN backbone output, but every location has been computed with global context. A lightweight decoder upsamples to pixel resolution:
\[\mathbb{R}^{14 \times 14 \times 768} \rightarrow \mathbb{R}^{224 \times 224 \times K}\]
For the masked encoder variant: masking during pretraining forces the model to reconstruct missing patches from context. The learned representations therefore encode *what is contextually normal* for a given location. Defect patches are those whose representations deviate from what the model predicts — the anomaly signal is baked into the representation space itself.
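
A sketch of the segmentation pivot. The nearest-neighbour upsampling and per-patch linear classifier are deliberately crude stand-ins for a real decoder; they are only here to make the shape trace concrete:

```python
import numpy as np

K = 10
Z = np.random.randn(197, 768)               # encoder output
feat_grid = Z[1:, :].reshape(14, 14, 768)   # drop CLS, restore the 14x14 patch grid

W_seg = np.random.randn(768, K) * 0.02      # per-patch classifier (random stand-in)
logits = feat_grid @ W_seg                  # [14, 14, K] coarse per-patch logits

# crude decoder: nearest-neighbour upsampling by the patch size (16)
masks = logits.repeat(16, axis=0).repeat(16, axis=1)
print(masks.shape)                          # (224, 224, 10)
```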

Complete Shape Trace

Step Operation Shape
Input Raw image \([224,\ 224,\ 3]\)
1 Patch extraction + flatten \([196,\ 768]\)
2 Linear embedding \(W_E\) \([196,\ 768]\)
3 Prepend CLS token \([197,\ 768]\)
4 Add positional encoding \([197,\ 768]\)
5 Layer Norm \([197,\ 768]\)
6a Project to Q, K, V \([197,\ 64]\) each
6b \(QK^\top / \sqrt{d_k}\) \([197,\ 197]\)
6c Softmax (row-wise) \([197,\ 197]\)
6d \(\hat{A} V\) \([197,\ 64]\)
6e Output projection \(W_O\) \([197,\ 768]\)
7 Residual add \([197,\ 768]\)
8 LN + MLP + residual \([197,\ 768]\)
9 — classification CLS token \(\rightarrow\) linear \([K]\)
10 — segmentation Patch tokens \(\rightarrow\) reshape \(\rightarrow\) upsample \([224,\ 224,\ K]\)