from Guide to Machine Learning on Apr 23, 2023

Practical Introduction to Large Language Models

In Illustrated Intuition for Transformers, we motivated and built the transformer from first principles, and in Illustrated Intuition for Self-Attention, we took a deep dive into what is termed the "self-attention" module.

In this post, we'll summarize what a transformer is. Consider this the more comprehensive reference: it covers the whole architecture, including conventions and best practices. We'll write expressions for the entire model with fully-specified dimensions and typical values for each variable, as well as recent modifications from state-of-the-art open-source models.

Self-attention Projections

The self-attention module first transforms the input $X$ into three different vector spaces.

Let $X \in \mathbb{R}^{n \times d_\text{model}}$ and $W_{\{Q, K, V\}} \in \mathbb{R}^{d_\text{model} \times d}$. Here, $n$ is the input sequence length, $d_\text{model}$ is the dimensionality of each token, and $d$ is the dimensionality of each projection; the key and value projections share this dimension, so $d_k = d_v = d$.

Compute the three projections:

$$\begin{align}Q &= XW_Q \in \mathbb{R}^{n \times d}\\ K &= XW_K \\ V &= XW_V\end{align}$$

Note that this computation can be parallelized more efficiently if we concatenate the weight matrices column-wise, perform one matrix multiply, and then split the output into $Q$, $K$, and $V$. Because $d_k = d_v = d$, the output splits into three equal chunks of $d$ columns each.

$$\begin{bmatrix}Q & K & V\end{bmatrix} = X\begin{bmatrix}W_Q & W_K & W_V\end{bmatrix} \in \mathbb{R}^{n \times 3d}$$
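As a sanity check, here is a minimal pure-Python sketch (toy sizes, random weights, no NumPy) showing that one multiply against the column-concatenated weights gives the same $Q$, $K$, $V$ as three separate multiplies. The helper names matmul, hconcat, and split_columns are our own, not from any library.

```python
import random

def matmul(a, b):
    """Multiply an (r x k) matrix by a (k x c) matrix, both stored as lists of rows."""
    return [[sum(row[t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for row in a]

def hconcat(*mats):
    """Concatenate matrices that share a row count along the column axis."""
    return [sum((m[r] for m in mats), []) for r in range(len(mats[0]))]

def split_columns(m, width):
    """Split a matrix column-wise into consecutive chunks of `width` columns."""
    return [[row[c:c + width] for row in m] for c in range(0, len(m[0]), width)]

def rand_matrix(rows, cols):
    return [[random.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]

random.seed(0)
n, d_model, d = 4, 8, 2  # toy sizes, chosen only for illustration
X = rand_matrix(n, d_model)
W_Q, W_K, W_V = (rand_matrix(d_model, d) for _ in range(3))

# Three separate projections: Q = X W_Q, K = X W_K, V = X W_V.
separate = [matmul(X, W) for W in (W_Q, W_K, W_V)]

# One projection against [W_Q W_K W_V] (d_model x 3d), then split into Q, K, V.
fused = split_columns(matmul(X, hconcat(W_Q, W_K, W_V)), d)

max_err = max(abs(a - b)
              for F, S in zip(fused, separate)
              for f_row, s_row in zip(F, S)
              for a, b in zip(f_row, s_row))
print(f"max difference between fused and separate projections: {max_err}")
```

The math is unchanged; on real hardware, the benefit comes from launching one larger matrix multiply instead of three smaller ones.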

Now, $Q, K, V \in \mathbb{R}^{n \times d}$. Depending on the architecture, each projection may also add a bias or apply a position-dependent transform. Here is how a few common architectures project the $Q$ and $K$ values, where $r(\cdot)$ denotes rotary positional embeddings (RoPE).

| Architecture | Query projection |
|---|---|
| OPT | $Q = XW_Q + b_Q$ |
| LLaMA | $Q = r(XW_Q)$ |

To capture all the different ways of projecting $X$, we'll represent these embeddings of $X$ as

$$X_{\{Q,K,V\}} = P_{\{Q,K,V\}}(X)$$

This step requires a total of $3dd_\text{model}$ parameters and outputs $3nd$ activations.

Self-attention

Next, we weight the $V$ projection by the similarity between the $Q$ and $K$ projections, using the following expression for self-attention.

$$\begin{align}Y_a &= \text{Attention}(X_Q, X_K, X_V) \\ &= \text{softmax}(\frac{X_QX_K^T}{\sqrt{d}})X_V\end{align}$$

Notice that $X_QX_K^T \in \mathbb{R}^{n \times n}$ provides a weight for each of the $n$ context (key) words when contextualizing each of the $n$ target (query) words. After the softmax, each row of weights sums to 1, and we take a weighted sum of the $n$ value vectors to generate the final output $Y_a \in \mathbb{R}^{n \times d}$.
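To make the shapes concrete, here is a minimal pure-Python sketch of single-head scaled dot-product attention. Random toy matrices stand in for real projections, and the helper names are our own rather than any library's API.

```python
import math
import random

def matmul(a, b):
    """Multiply an (r x k) matrix by a (k x c) matrix, both stored as lists of rows."""
    return [[sum(row[t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for row in a]

def softmax_rows(m):
    """Row-wise softmax; subtracting the row max keeps exp() from overflowing."""
    out = []
    for row in m:
        top = max(row)
        exps = [math.exp(x - top) for x in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

def attention(XQ, XK, XV):
    """softmax(XQ XK^T / sqrt(d)) XV for a single head; also returns the weights."""
    d = len(XQ[0])
    # scores[i][j] = <query i, key j> / sqrt(d), an (n x n) matrix
    scores = [[sum(q * k for q, k in zip(q_row, k_row)) / math.sqrt(d) for k_row in XK]
              for q_row in XQ]
    weights = softmax_rows(scores)
    return matmul(weights, XV), weights

def rand_matrix(rows, cols):
    return [[random.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]

random.seed(0)
n, d = 5, 4  # toy sizes
XQ, XK, XV = (rand_matrix(n, d) for _ in range(3))
Y_a, weights = attention(XQ, XK, XV)
print(len(Y_a), len(Y_a[0]))                     # 5 4
print([round(sum(row), 6) for row in weights])  # [1.0, 1.0, 1.0, 1.0, 1.0]
```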

In practice, softmax may be replaced when exporting a model for deployment, since it can dominate latency on edge devices.¹

This step adds no parameters but produces several intermediate outputs. Counting across all $h$ attention heads (introduced in the next section), these are: (1) the matrix product $X_QX_K^T$, which yields an $hn^2$ activation, (2) the elementwise division by $\sqrt{d}$, which yields another $hn^2$, (3) the softmax, which yields a third $hn^2$, and (4) the final $nhd$ output. This is a total of $3hn^2 + nhd = (3n+d)hn$ activations.

Multi-head Attention

The transformer then repeats this attention computation several times in parallel; each repeat is called a head. The number of heads is generally no more than 100, with typical values $h \in \{32, 40, 48, 64\}$. For all heads at once, we first transform $X$ into the $Q, K, V$ spaces to obtain

$$X_{\{Q,K,V\}} \in \mathbb{R}^{n \times hd}$$

We then split each projection column-wise into $h$ slices $X_{Q,i}, X_{K,i}, X_{V,i} \in \mathbb{R}^{n \times d}$, where $0 \leq i < h$ indexes the head, and perform the attention computation above on each head.

$$A_i = \text{Attention}(X_{Q,i}, X_{K,i}, X_{V,i}) \in \mathbb{R}^{n \times d}$$

We then concatenate all $h$ head outputs $A_i$ and multiply by another weight matrix $W_O \in \mathbb{R}^{hd \times d_\text{model}}$, which converts the concatenated output's dimensionality back to $d_\text{model}$.

$$\begin{align}Y_A &= \text{Multihead}(X_Q,X_K,X_V)\\ &= \text{Concat}(\{A_i\}_{i=0}^{h-1})W_O \in \mathbb{R}^{n \times d_\text{model}}\end{align}$$
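The full multi-head computation can be sketched the same way: project once for all heads, slice out each head's $d$ columns, attend, concatenate, and apply $W_O$. This is a minimal pure-Python sketch under toy sizes with random weights; the function names are hypothetical, and it follows the common convention $hd = d_\text{model}$.

```python
import math
import random

def matmul(a, b):
    """Multiply an (r x k) matrix by a (k x c) matrix, both stored as lists of rows."""
    return [[sum(row[t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for row in a]

def attention(XQ, XK, XV):
    """Single-head softmax(XQ XK^T / sqrt(d)) XV."""
    d = len(XQ[0])
    weights = []
    for q_row in XQ:
        scores = [sum(q * k for q, k in zip(q_row, k_row)) / math.sqrt(d) for k_row in XK]
        top = max(scores)
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        weights.append([e / total for e in exps])
    return matmul(weights, XV)

def multihead(X, W_Q, W_K, W_V, W_O, h):
    """Project once for all heads, attend per head, concatenate, then apply W_O."""
    XQ, XK, XV = matmul(X, W_Q), matmul(X, W_K), matmul(X, W_V)  # each (n x hd)
    d = len(XQ[0]) // h
    heads = []
    for i in range(h):
        cols = slice(i * d, (i + 1) * d)  # the d columns that belong to head i
        heads.append(attention([r[cols] for r in XQ],
                               [r[cols] for r in XK],
                               [r[cols] for r in XV]))  # A_i: (n x d)
    concat = [sum((A[r] for A in heads), []) for r in range(len(X))]  # (n x hd)
    return matmul(concat, W_O)  # (n x d_model)

def rand_matrix(rows, cols):
    return [[random.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]

random.seed(0)
n, d_model, h = 3, 8, 2
d = d_model // h  # so that hd = d_model
X = rand_matrix(n, d_model)
W_Q, W_K, W_V = (rand_matrix(d_model, h * d) for _ in range(3))
W_O = rand_matrix(h * d, d_model)
Y_A = multihead(X, W_Q, W_K, W_V, W_O, h)
print(len(Y_A), len(Y_A[0]))  # 3 8
```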

In practice, you'll find that the attention weight matrices $W_{\{Q, K, V\}}$ are $d_\text{model} \times hd$. For example, LLaMA 7B has $h = 32$ heads of dimension $d = 128$, so these weight matrices are $4096 \times (128 \cdot 32) = 4096 \times 4096$. Usually, $hd = d_\text{model} \in \{4096, 5120, 6144, 8192\}$.

This multi-head attention contains $4hdd_\text{model}$ parameters, or $4d_\text{model}^2$ when $hd = d_\text{model}$. Along the way, it reads and writes $3nhd + (3n+d)hn = (3n + 4d)hn$ intermediate activations: $3nhd$ from projecting $X$ into the $Q,K,V$ spaces, plus the attention intermediates counted above. It then produces an $nhd = nd_\text{model}$ activation by concatenating the heads and an $nd_\text{model}$ output after $W_O$. For simplicity, if we assume this self-attention operation is "fused", as we discuss in When to tile two matrix multiplies, there are no intermediate activations to account for, and we can say that self-attention generates $2nd_\text{model}$ activations.

For example, this is $4 \cdot 4096^2 \approx 67 \times 10^6$ parameters for LLaMA 7B.
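These counts are easy to tabulate. The sketch below plugs in the LLaMA 7B values quoted above; the sequence length $n = 2048$ is only an example, and the function names are our own.

```python
def mha_params(d_model, h, d):
    """W_Q, W_K, W_V (each d_model x hd) plus W_O (hd x d_model); biases ignored."""
    return 4 * h * d * d_model

def unfused_intermediates(n, h, d):
    """3nhd for the Q, K, V projections, plus (3n + d)hn from the attention step."""
    return 3 * n * h * d + (3 * n + d) * h * n

def fused_activations(n, d_model):
    """Concatenated head outputs plus the W_O output, assuming attention is fused."""
    return 2 * n * d_model

d_model, h, d = 4096, 32, 128  # LLaMA 7B
n = 2048                       # example sequence length (an assumption)
print(mha_params(d_model, h, d))       # 67108864
print(unfused_intermediates(n, h, d))  # 436207616
print(fused_activations(n, d_model))   # 16777216
```

Note how the unfused count grows with $n^2$, while the fused count grows only linearly in $n$.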

MLP

Immediately after self-attention is an MLP, which operates on the output $Y_A$ above. Let $W_U \in \mathbb{R}^{d_\text{model} \times 4d_\text{model}}$ and $W_D \in \mathbb{R}^{4d_\text{model} \times d_\text{model}}$, with biases $b_U$ and $b_D$. Then, the MLP computes

$$Y_M = \text{MLP}(Y_A) = \sigma(Y_AW_U + b_U)W_D + b_D \in \mathbb{R}^{n \times d_\text{model}}$$

In general, $W_U \in \mathbb{R}^{d_\text{model} \times d_\text{hidden}}$, where the hidden dimension $d_\text{hidden}$ can be any value. In practice, $d_\text{hidden} \approx 4d_\text{model}$, which is why we use $4d_\text{model}$ directly above rather than introduce another variable. HuggingFace and other libraries expose the hidden dimension $d_\text{hidden}$ as a configurable hyperparameter (for example, intermediate_size in HuggingFace's LlamaConfig).

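| Architecture | MLP |
|---|---|
| OPT | $\text{MLP}(X) = \text{ReLU}(XW_U + b_U)W_D + b_D$ |
| LLaMA | $\text{MLP}(X) = (\text{silu}(XW_G) \odot XW_U)W_D$ |

Here, $W_G \in \mathbb{R}^{d_\text{model} \times d_\text{hidden}}$ is an additional "gate" projection, and $\odot$ denotes the elementwise product.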
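To contrast the two MLP variants in the table above, here is a minimal pure-Python sketch applied to a single token vector. It assumes OPT's nonlinearity $\sigma$ is ReLU with biases, and LLaMA's gated form has three bias-free matrices; the hidden sizes and helper names are hypothetical toy choices.

```python
import math
import random

def vecmat(x, W):
    """Row vector x (length r) times matrix W (r x c)."""
    return [sum(x[t] * W[t][j] for t in range(len(W))) for j in range(len(W[0]))]

def relu(v):
    return [max(0.0, a) for a in v]

def silu(v):
    """silu(a) = a * sigmoid(a) = a / (1 + exp(-a))."""
    return [a / (1.0 + math.exp(-a)) for a in v]

def opt_mlp(x, W_U, b_U, W_D, b_D):
    """ReLU(x W_U + b_U) W_D + b_D."""
    hidden = relu([a + b for a, b in zip(vecmat(x, W_U), b_U)])
    return [a + b for a, b in zip(vecmat(hidden, W_D), b_D)]

def llama_mlp(x, W_G, W_U, W_D):
    """(silu(x W_G) * x W_U) W_D, where * is the elementwise product."""
    gated = [g * u for g, u in zip(silu(vecmat(x, W_G)), vecmat(x, W_U))]
    return vecmat(gated, W_D)

def rand_matrix(rows, cols):
    return [[random.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]

random.seed(0)
d_model = 4
x = rand_matrix(1, d_model)[0]  # a single token's embedding

# OPT-style: d_hidden = 4 * d_model, with biases.
h_opt = 4 * d_model
y_opt = opt_mlp(x, rand_matrix(d_model, h_opt), [0.1] * h_opt,
                rand_matrix(h_opt, d_model), [0.0] * d_model)

# LLaMA-style: three bias-free matrices; d_hidden = 11 is an arbitrary toy value.
h_llama = 11
y_llama = llama_mlp(x, rand_matrix(d_model, h_llama), rand_matrix(d_model, h_llama),
                    rand_matrix(h_llama, d_model))
print(len(y_opt), len(y_llama))  # 4 4
```

Both variants map a $d_\text{model}$-dimensional token back to $d_\text{model}$ dimensions; the gated form simply spends its parameters on a third matrix instead of a wider hidden layer.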

Ignoring biases, this MLP contains $2d_\text{model}d_\text{hidden} \approx 8d_\text{model}^2$ parameters and $n(d_\text{hidden} + d_\text{model}) \approx 5nd_\text{model}$ activations.

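In LLaMA 7B, $d_\text{hidden} = 2.6875\,d_\text{model}$ and the MLP has three weight matrices ($W_G$, $W_U$, $W_D$), so it contains $(3 \cdot 2.6875) d_\text{model}^2 = 8.0625 \cdot 4096^2 \approx 135 \times 10^6$ parameters, close to the $8d_\text{model}^2$ rule of thumb. Because both $XW_G$ and $XW_U$ are hidden-sized, it produces $n(2d_\text{hidden} + d_\text{model}) = n(2 \cdot 2.6875 + 1) d_\text{model} \approx 6.4nd_\text{model}$ activations.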
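The arithmetic above can be checked directly. This sketch assumes LLaMA 7B's hidden size is $d_\text{hidden} = 11008$ (that is, $2.6875 \times 4096$, per the public model configuration) and ignores biases, which LLaMA does not use.

```python
d_model, d_hidden = 4096, 11008  # LLaMA 7B

print(d_hidden / d_model)            # 2.6875
mlp_params = 3 * d_model * d_hidden  # W_G, W_U, W_D
print(mlp_params)                    # 135266304
print(8 * d_model ** 2)              # 134217728 (the rule of thumb)

# Per token: the gate and up outputs (d_hidden each) plus the MLP output (d_model).
activations_per_token = 2 * d_hidden + d_model
print(activations_per_token / d_model)  # 6.375
```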

Encoder

We now consider the self-attention and MLP modules to be one combined "encoder" module.

$$\text{Encoder}(X) = \text{MLP}(\text{Multihead}(X_Q,X_K,X_V))$$

Stack many encoders and apply them to the prompt $P$ to obtain a final embedding:

$$E = \text{Encode}(P) = \text{Encoder}(\cdots\text{Encoder}(P)) \in \mathbb{R}^{n \times d_\text{model}}$$

This embedding represents our prompt, contextualized.

Each encoder block uses about $12d_\text{model}^2$ parameters and produces a total of $2nd_\text{model} + 5nd_\text{model} = 7nd_\text{model}$ activations, again assuming that the self-attention module is fused.

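In LLaMA 7B's case, this estimate gives $12 \cdot 4096^2 \approx 201 \times 10^6$ parameters. The exact count, using the gated MLP above, is $4 \cdot 4096^2 + 8.0625 \cdot 4096^2 \approx 202 \times 10^6$, or about 202 million parameters.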
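A small helper makes the per-block count explicit. This is a sketch that ignores biases and normalization weights; it assumes LLaMA 7B's $d_\text{hidden} = 11008$ and 32 stacked blocks, both taken from the public model configuration rather than from this post.

```python
def block_params(d_model, d_hidden, n_mlp_matrices):
    """4 d_model^2 for attention (W_Q, W_K, W_V, W_O) plus the MLP weight matrices."""
    return 4 * d_model ** 2 + n_mlp_matrices * d_model * d_hidden

# Rule of thumb: d_hidden = 4 d_model with two MLP matrices gives 12 d_model^2.
print(block_params(4096, 4 * 4096, 2) == 12 * 4096 ** 2)  # True

# LLaMA 7B: gated MLP with three matrices and d_hidden = 11008.
llama_block = block_params(4096, 11008, 3)
print(llama_block)     # 202375168
print(12 * 4096 ** 2)  # 201326592

# 32 blocks; this total excludes embeddings, the output head, and normalization weights.
print(32 * llama_block)  # 6476005376
```

Most of LLaMA 7B's roughly 6.7 billion parameters therefore live in these repeated blocks.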

Decoder

We'll repeat this for the next-word prediction portion of the model. In this case, we take the encoded prompt $E$ and use that for the keys and values. Queries are taken from the previous output word $X$.

$$\begin{align} \text{Decoder}(X, E) &= \text{MLP}(\text{Multihead}(X_Q,E_K,E_V))\\ &= \text{MLP}(\text{Multihead}(P_Q(X),P_K(E),P_V(E)))\end{align}$$
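The only change from self-attention is where the keys and values come from. In this minimal pure-Python sketch (toy sizes, random weights, hypothetical helper names), $m$ decoder-side rows attend over $n$ encoded prompt rows, so the attention weights are $m \times n$ rather than square.

```python
import math
import random

def matmul(a, b):
    """Multiply an (r x k) matrix by a (k x c) matrix, both stored as lists of rows."""
    return [[sum(row[t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for row in a]

def attention(XQ, XK, XV):
    """softmax(XQ XK^T / sqrt(d)) XV; XQ may have a different row count than XK, XV."""
    d = len(XQ[0])
    weights = []
    for q_row in XQ:
        scores = [sum(q * k for q, k in zip(q_row, k_row)) / math.sqrt(d) for k_row in XK]
        top = max(scores)
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        weights.append([e / total for e in exps])
    return matmul(weights, XV), weights

def rand_matrix(rows, cols):
    return [[random.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]

random.seed(0)
m, n, d_model, d = 2, 5, 6, 3  # m decoder-side words, n prompt words (toy sizes)
X = rand_matrix(m, d_model)    # decoder input: previous output word(s)
E = rand_matrix(n, d_model)    # encoded prompt from the encoder stack
W_Q, W_K, W_V = (rand_matrix(d_model, d) for _ in range(3))

# Queries come from X; keys and values come from E.
Y, weights = attention(matmul(X, W_Q), matmul(E, W_K), matmul(E, W_V))
print(len(weights), len(weights[0]))  # 2 5
print(len(Y), len(Y[0]))              # 2 3
```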

Stack many decoders and use them auto-regressively.

$$\begin{align}Y_i &= \text{Decode}_{E := \text{Encode}(P)}(Y_{i-1})\\ &= \text{Decoder}(\cdots \text{Decoder}(Y_{i-1}, E), E)\end{align}$$

This means that we feed the previous outputs $Y_{i-1}$ as input to the model, to generate the next output $Y_i$.
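The autoregressive loop itself is short. In this sketch, the prompt is encoded once and reused at every step, and each new output is computed from the previous one, mirroring $Y_i = \text{Decode}_{E}(Y_{i-1})$. The function toy_decode_step is a hypothetical stand-in for the decoder stack plus de-tokenization; it is not a real model.

```python
def generate(decode_step, encoded_prompt, start_token, max_new_tokens, stop_token=None):
    """Greedy autoregressive loop: each new output Y_i is computed from Y_{i-1}."""
    outputs = [start_token]
    for _ in range(max_new_tokens):
        next_token = decode_step(outputs[-1], encoded_prompt)  # Y_i = Decode_E(Y_{i-1})
        outputs.append(next_token)
        if next_token == stop_token:
            break
    return outputs[1:]  # drop the start token

def toy_decode_step(previous, encoded_prompt):
    """Hypothetical stand-in for the decoder stack plus DeTok: add an offset mod 10."""
    return (previous + encoded_prompt) % 10

print(generate(toy_decode_step, encoded_prompt=3, start_token=0, max_new_tokens=5))
# [3, 6, 9, 2, 5]
print(generate(toy_decode_step, encoded_prompt=3, start_token=0, max_new_tokens=5,
               stop_token=9))
# [3, 6, 9]
```

Because each step depends on the previous output, the steps cannot run in parallel at inference time; this is why generation cost grows with the number of output words.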

De-tokenize

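Now we have the output embedding $Y_i \in \mathbb{R}^{d_\text{model}}$ for the $i$-th word. To decode it back into a word, use a typical classification head: a mapping $W_\text{DeTok} \in \mathbb{R}^{d_\text{model} \times D}$ from $d_\text{model}$ to the dictionary size $D$, which produces the logits $YW_\text{DeTok}$. This gives us

$$\text{DeTok}(Y) = \text{argmax}(\text{softmax}(YW_\text{DeTok}))$$

Since softmax preserves the ordering of the logits, the argmax can be taken over the logits directly; the softmax probabilities matter when sampling a word instead of picking the most likely one. In practice, the dictionary size is around 32,000 (exactly 32,000 for LLaMA).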
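Here is a minimal sketch of this classification head for one output embedding. The six-word vocabulary and random weights are hypothetical toy stand-ins for a real 32,000-entry dictionary; the last line checks that softmax does not change the argmax.

```python
import math
import random

def detokenize(y, W_detok):
    """Return the argmax token id, the logits, and the softmax probabilities."""
    # logits = y W_detok: a length-d_model row vector times a (d_model x D) matrix
    logits = [sum(y[t] * W_detok[t][j] for t in range(len(y)))
              for j in range(len(W_detok[0]))]
    top = max(logits)
    exps = [math.exp(z - top) for z in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    return max(range(len(probs)), key=probs.__getitem__), logits, probs

random.seed(0)
d_model = 4
vocab = ["the", "cat", "sat", "on", "a", "mat"]  # hypothetical dictionary, D = 6
W_detok = [[random.uniform(-1.0, 1.0) for _ in vocab] for _ in range(d_model)]
y = [random.uniform(-1.0, 1.0) for _ in range(d_model)]

token_id, logits, probs = detokenize(y, W_detok)
print(vocab[token_id])
# Softmax is monotonic, so the argmax of the raw logits picks the same word.
print(token_id == max(range(len(logits)), key=logits.__getitem__))  # True
```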

Transformer

Here is our transformer in summary. In short, for an output of $s$ words, convert each decoder output $Y_i$ into a word using the classification head.

$$\{\text{DeTok}(Y_i): i \in [0, s)\}$$

These outputs come from an autoregressive decoding process, which predicts the next output word from the prompt and the previous output words. This iterative process is central to how Large Language Models generate text.

$$Y_i = \text{Transformer}(P, Y_{i-1}) = \text{Decode}_{\text{Encode}(P)}(Y_{i-1})$$

Both the encode and decode steps are actually cascaded sets of layers that look like the following. These layers consume and output tokens of the same dimensionality throughout the entire model. As a result, every encoder has the exact same architecture.

$$E = \text{Encode}(P) = \text{Encoder}(\cdots\text{Encoder}(P)) \in \mathbb{R}^{n \times d_\text{model}}$$

The goal of these layers is to (1) add context and (2) transform the contextualized tokens. The former is accomplished with a self-attention module and the latter with an MLP. Put all together, here is what an encoder looks like.

$$\text{Encoder}(X) = \text{MLP}(\text{Multihead}(X_Q,X_K,X_V))$$

In the above encoder, the self-attention layer looks like the following. The goal is to use an inner product to establish the "importance" of each key word for contextualizing a query word. Then, these "importance" weights are used to take a weighted sum of the input, projected into value space.

$$\text{Attention}(X_Q, X_K, X_V) = \text{softmax}(\frac{X_QX_K^T}{\sqrt{d}})X_V \in \mathbb{R}^{n \times d}$$

The MLP looks like the following, where a very large hidden size has been empirically found to improve performance.

$$Y_M = \text{MLP}(Y_A) = \sigma(Y_AW_U)W_D \in \mathbb{R}^{n \times d_\text{model}}$$

Every one of these blocks also features $12d_\text{model}^2$ parameters and $7nd_\text{model}$ total activations.

And that's it! This covers the transformer architecture in its entirety, along with relevant concepts and a fully-specified expression for the model, and concludes our 4-part series on how Large Language Models work.





  1. There are a number of papers that propose approximations for softmax that run more quickly at inference time. See SOFT: Softmax-free Transformer with Linear Complexity by Lu et al. or cosFormer: Rethinking Softmax In Attention by Qin et al.