I figure the best way to understand Vision Transformers is to trace the concrete tensor shapes that run through the whole pipeline, one operation at a time.
Setup and Notation
| Symbol | Value | Meaning |
|---|---|---|
| \(H, W, C\) | \(224, 224, 3\) | Image height, width, channels |
| \(P\) | \(16\) | Patch size |
| \(N\) | \(196\) | Number of patches (\(14 \times 14\)) |
| \(D\) | \(768\) | Embedding dimension |
| \(d_k\) | \(64\) | Attention head dimension |
| \(L\) | \(1\) | Encoder blocks traced here (ViT-Base stacks 12) |
| \(K\) | \(10\) | Number of output classes |
Step 1 — Patch Extraction
The image \(I \in \mathbb{R}^{224 \times 224 \times 3}\) is divided into a grid of non-overlapping
\(16 \times 16\) patches. Each patch is flattened into a vector:
\[\mathbf{x}_i \in \mathbb{R}^{P^2 \cdot C} = \mathbb{R}^{768} \qquad i = 1, \dots, 196\]
Stacked into a matrix:
\[X_{\text{patches}} \in \mathbb{R}^{196 \times 768}\]
Shape: \([196,\ 768]\)
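A quick PyTorch sketch of this reshape (hypothetical tensor names; the permute keeps the patch grid in row-major order):

```python
import torch

H = W = 224; P = 16; C = 3
img = torch.randn(H, W, C)                # one image, channels-last for clarity

# split each spatial axis into (grid index, offset within patch)
x = img.reshape(H // P, P, W // P, P, C)  # [14, 16, 14, 16, 3]
x = x.permute(0, 2, 1, 3, 4)              # [14, 14, 16, 16, 3]: group the grid axes
patches = x.reshape(-1, P * P * C)        # flatten each 16x16x3 patch
print(patches.shape)                      # torch.Size([196, 768])
```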
Step 2 — Linear Patch Embedding
Raw pixel vectors live in a space whose axes are individual pixel intensities. This is a poor
space for reasoning — there is no semantic structure along those axes. The embedding projects
each patch into a learned space \(\mathbb{R}^D\) where geometrically meaningful relationships
can emerge.
\[Z_{\text{patches}} = X_{\text{patches}}\, W_E + \mathbf{b}_E \qquad W_E \in \mathbb{R}^{768 \times 768},\
\mathbf{b}_E \in \mathbb{R}^{768}\]
The analogy for a physicist: choosing the right basis for a Hamiltonian. Working in the pixel
basis is like staying in the position basis when the energy eigenbasis reveals all the structure.
\(W_E\) is learned — the training process finds the projection.
Shape: \([196,\ 768]\)
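Continuing the sketch, the embedding is a single `nn.Linear` (PyTorch stores the weight as \(W_E^\top\), but the shape is `[768, 768]` either way):

```python
import torch.nn as nn

embed = nn.Linear(P * P * C, 768)   # W_E and b_E, both learned
z_patches = embed(patches)          # [196, 768]
```

In practice, Steps 1 and 2 are usually fused into one `nn.Conv2d(3, 768, kernel_size=16, stride=16)`: the same linear map, applied patch-wise by the convolution.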
Step 3 — CLS Token
A single learnable vector \(\mathbf{z}_{\text{cls}} \in \mathbb{R}^{768}\) is prepended to the
sequence. It carries no pixel information — it is a trained parameter, initialized randomly and
updated by backpropagation. By the end of the encoder, it will have aggregated global context
from all other tokens through attention. For classification, it is the only token used.
\[Z = \left[\mathbf{z}_{\text{cls}}\ ;\ Z_{\text{patches}}\right] \in \mathbb{R}^{197 \times 768}\]
Shape: \([197,\ 768]\)
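In the sketch, the CLS token is just another parameter (small random init, per the text):

```python
cls_token = nn.Parameter(torch.randn(1, 768) * 0.02)  # no pixel input; learned via backprop
z = torch.cat([cls_token, z_patches], dim=0)          # [197, 768]
```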
Step 4 — Positional Encoding
The attention operation (defined in Step 6) has no built-in notion of order: permuting the rows of
\(Z\) simply permutes the output rows the same way, so nothing downstream can tell where a patch
came from. Spatial position must therefore be injected explicitly. A learned
matrix \(P_{\text{pos}} \in \mathbb{R}^{197 \times 768}\) is added element-wise:
\[Z \leftarrow Z + P_{\text{pos}}\]
One position vector per token, each a learned \(D\)-dimensional parameter. After this step every
token carries both what it looks like and where it is.
Shape: \([197,\ 768]\) — unchanged
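Same pattern in the sketch: one learned row per token position, added in place:

```python
pos_embed = nn.Parameter(torch.randn(197, 768) * 0.02)  # P_pos, one vector per token
z = z + pos_embed                                       # element-wise; shape unchanged
```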
Step 5 — Layer Normalization
Before attention, each token vector is independently normalized across its \(D\) feature dimensions:
\[\text{LN}(\mathbf{x}) = \gamma \odot \frac{\mathbf{x} - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta\]
where \(\mu\) and \(\sigma^2\) are the scalar mean and variance of \(\mathbf{x} \in \mathbb{R}^D\),
and \(\gamma, \beta \in \mathbb{R}^D\) are learned scale and shift parameters.
Applied row-wise: each of the 197 vectors is normalized independently. Without this, the
distribution of activations shifts unpredictably as weights update during training, letting the
dot products \(\mathbf{q}^\top \mathbf{k}\) in attention grow uncontrolled. The learned
\(\gamma, \beta\) ensure the normalization does not discard representational capacity: the
network can undo the normalization if needed.
\[\tilde{Z} = \text{LN}(Z)\]
Shape: \([197,\ 768]\) — unchanged
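`nn.LayerNorm(768)` implements exactly this row-wise normalization. A manual check against one row of the sketch (biased variance, matching PyTorch's convention):

```python
ln = nn.LayerNorm(768)          # gamma = ln.weight, beta = ln.bias, both [768]
z_tilde = ln(z)                 # all 197 rows normalized independently

mu, var = z[0].mean(), z[0].var(unbiased=False)
row0 = ln.weight * (z[0] - mu) / torch.sqrt(var + ln.eps) + ln.bias
assert torch.allclose(row0, z_tilde[0], atol=1e-5)
```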
Step 6 — Self-Attention
Motivation
A convolution at layer \(l\) sees only a \(k \times k\) neighborhood. Relating a patch in the
top-left corner to one in the bottom-right requires \(O(\text{image\_size}/k)\) stacked layers —
long-range dependencies are expensive and lossy.
Attention asks a different question: given a set of \(N\) vectors, how should each vector update
itself by aggregating information from all others, weighted by relevance? For defect detection,
the relevance of a patch often depends on context far away — a scratch is only meaningful given
knowledge of the surrounding material.
6a — Projecting into Q, K, V
Three learned matrices \(W_Q, W_K, W_V \in \mathbb{R}^{768 \times 64}\) project each token into
three separate 64-dimensional subspaces:
\[Q = \tilde{Z}\, W_Q, \qquad K = \tilde{Z}\, W_K, \qquad V = \tilde{Z}\, W_V\]
\[\mathbb{R}^{197 \times 768} \times \mathbb{R}^{768 \times 64} \rightarrow \mathbb{R}^{197 \times 64} \quad
\text{each}\]
All three come from the same \(\tilde{Z}\). This is self-attention: each token simultaneously
plays three roles, and the three projection matrices are learned independently, each specializing
for one of them. The query (\(W_Q\)) encodes what a token is looking for, the key (\(W_K\)) what
it offers to be matched against, and the value (\(W_V\)) what information it passes along once
selected.
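6b — Attention Scores
\[A = \frac{Q K^\top}{\sqrt{d_k}} \in \mathbb{R}^{197 \times 197}\]
Entry \(A_{ij}\) is the scaled dot product between token \(i\)'s query and token \(j\)'s key: how
relevant token \(j\) is to token \(i\). Dividing by \(\sqrt{d_k}\) keeps the scores from growing
with dimension, which would push the softmax into saturation.
6c — Row-wise Softmax
\[\hat{A} = \operatorname{softmax}(A) \in \mathbb{R}^{197 \times 197}\]
Each row becomes a probability distribution over all 197 tokens.
6d — Weighted Sum of Values
\[\hat{A} V \in \mathbb{R}^{197 \times 64}\]
Each token's update is a relevance-weighted average of every token's value vector.
6e — Output Projection
\[\hat{A} V\, W_O \in \mathbb{R}^{197 \times 768}, \qquad W_O \in \mathbb{R}^{64 \times 768}\]
mapping the head back to \(\mathbb{R}^D\). A full ViT-Base runs \(D / d_k = 12\) such heads in
parallel and concatenates them before the output projection; this trace follows a single head,
which is all the MSA of Step 7 reduces to here.
Shape: \([197,\ 768]\)
A minimal single-head sketch tying 6a through 6e together (continuing the hypothetical PyTorch trace; `z_tilde` is the LayerNorm output from Step 5):

```python
D, d_k = 768, 64
W_Q, W_K, W_V = (nn.Linear(D, d_k, bias=False) for _ in range(3))
W_O = nn.Linear(d_k, D, bias=False)

def attention(z_tilde):                                  # z_tilde: [197, 768]
    Q, K, V = W_Q(z_tilde), W_K(z_tilde), W_V(z_tilde)   # 6a: each [197, 64]
    A = Q @ K.T / d_k ** 0.5                             # 6b: [197, 197] scaled scores
    A_hat = A.softmax(dim=-1)                            # 6c: each row sums to 1
    return W_O(A_hat @ V)                                # 6d + 6e: [197, 64] -> [197, 768]
```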
Step 7 — Residual Connection (post-attention)
\[Z \leftarrow Z + \text{MSA}(\text{LN}(Z))\]
\[\mathbb{R}^{197 \times 768} + \mathbb{R}^{197 \times 768} \rightarrow \mathbb{R}^{197 \times 768}\]
The residual adds the original \(Z\) back to the attention output. Two purposes: (1) a gradient
highway — gradients flow directly through the addition without passing through the attention
computation; (2) the attention block only needs to learn a correction to the identity, which
is an easier optimization target than learning the full transformation from scratch.
Shape: \([197,\ 768]\)
Step 8 — Layer Normalization + MLP Block
A second LayerNorm followed by a position-wise two-layer MLP applied independently to each token:
\[\text{MLP}(\mathbf{x}) = \text{GELU}(\mathbf{x}\, W_1 + \mathbf{b}_1)\, W_2 + \mathbf{b}_2\]
\[W_1 \in \mathbb{R}^{768 \times 3072}, \quad W_2 \in \mathbb{R}^{3072 \times 768}\]
Per token: \(\mathbb{R}^{768} \rightarrow \mathbb{R}^{3072} \rightarrow \mathbb{R}^{768}\). Applied to
all 197 tokens simultaneously. Followed by a second residual:
\[Z \leftarrow Z + \text{MLP}(\text{LN}(Z))\]
Attention mixes information across tokens. The MLP processes each token independently after
that mixing. The two operations are complementary — global routing followed by local refinement.
Shape: \([197,\ 768]\) — unchanged throughout
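Putting Steps 5 through 8 together as one pre-norm encoder block (same hypothetical sketch; a real ViT stacks \(L\) of these):

```python
mlp = nn.Sequential(
    nn.Linear(768, 3072),  # W_1, b_1: expand 4x
    nn.GELU(),
    nn.Linear(3072, 768),  # W_2, b_2: project back
)
ln1, ln2 = nn.LayerNorm(768), nn.LayerNorm(768)

def encoder_block(z, attention):     # attention: the 6a-6e computation above
    z = z + attention(ln1(z))        # Steps 5-7: global mixing + residual
    z = z + mlp(ln2(z))              # Step 8: per-token refinement + residual
    return z                         # [197, 768]
```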
Step 9 — Classification Head
Extract the CLS token (row 0) and apply a linear classifier:
\[\hat{\mathbf{y}} = Z[0,\ :]\, W_{\text{cls}} \qquad W_{\text{cls}} \in \mathbb{R}^{768 \times K}\]
Shape: \([768] \rightarrow [K]\)
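The head itself is one matrix (bias-free as written above; many implementations add a bias):

```python
num_classes = 10                                # K
head = nn.Linear(768, num_classes, bias=False)  # W_cls
logits = head(z[0])                             # CLS token (row 0) only -> [10]
```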
Step 10 — Segmentation Pivot
Classification uses the CLS token — one label per image, spatial information discarded.
Segmentation requires one label per pixel — spatial information must be preserved.
Discard the CLS token. Keep all patch tokens:
\[Z_{\text{patches}} = Z[1:,\ :] \in \mathbb{R}^{196 \times 768}\]
Reshape to a spatial grid:
\[\mathbb{R}^{196 \times 768} \rightarrow \mathbb{R}^{14 \times 14 \times 768}\]
This is a spatial feature map — structurally equivalent to a CNN backbone output, but every
location has been computed with global context. A lightweight decoder upsamples to pixel
resolution:
\[\mathbb{R}^{14 \times 14 \times 768} \rightarrow \mathbb{R}^{224 \times 224 \times K}\]
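A deliberately minimal decoder sketch, assuming a 1x1 convolution for per-location class scores followed by bilinear upsampling; real segmentation decoders are heavier:

```python
import torch.nn.functional as F

feat = z[1:].reshape(14, 14, 768)          # drop CLS, restore the patch grid
feat = feat.permute(2, 0, 1).unsqueeze(0)  # [1, 768, 14, 14], channels-first

decoder = nn.Conv2d(768, num_classes, kernel_size=1)
seg = decoder(feat)                                        # [1, 10, 14, 14]
seg = F.interpolate(seg, size=(224, 224),
                    mode="bilinear", align_corners=False)  # [1, 10, 224, 224]
```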
For the masked encoder variant: masking during pretraining forces the model to reconstruct
missing patches from context. The learned representations therefore encode *what is contextually
normal* for a given location. Defect patches are those whose representations deviate from what
the model predicts — the anomaly signal is baked into the representation space itself.
Complete Shape Trace
| Step | Operation | Shape |
|---|---|---|
| Input | Raw image | \([224,\ 224,\ 3]\) |
| 1 | Patch extraction + flatten | \([196,\ 768]\) |
| 2 | Linear embedding \(W_E\) | \([196,\ 768]\) |
| 3 | Prepend CLS token | \([197,\ 768]\) |
| 4 | Add positional encoding | \([197,\ 768]\) |
| 5 | Layer Norm | \([197,\ 768]\) |
| 6a | Project to Q, K, V | \([197,\ 64]\) each |
| 6b | \(QK^\top / \sqrt{d_k}\) | \([197,\ 197]\) |
| 6c | Softmax (row-wise) | \([197,\ 197]\) |
| 6d | \(\hat{A} V\) | \([197,\ 64]\) |
| 6e | Output projection \(W_O\) | \([197,\ 768]\) |
| 7 | Residual add | \([197,\ 768]\) |
| 8 | LN + MLP + residual | \([197,\ 768]\) |
| 9 — classification | CLS token \(\rightarrow\) linear | \([K]\) |
| 10 — segmentation | Patch tokens \(\rightarrow\) reshape \(\rightarrow\) upsample | \([224,\ 224,\ K]\) |