from Guide to Machine Learning on Apr 16, 2023

Illustrated Intuition for Self-Attention

Large Language Models transform input words into output words, and to do this, they need to incorporate context across all inputs. The module that adds this context is called self-attention — the focus of this post.

Large Language Models rely on a fundamental building block called transformers, which we covered in Illustrated Intuition for Transformers. In this post, we dive deeper into how the transformer operates, illustrating one of its core functions — namely, to add context to its inputs via a module called self-attention.

Previously, we built the transformer pipeline from the ground up, arriving at the following 6 steps:

  1. Encode each word in our prompt into a vector.
  2. Contextualize each prompt vector.
  3. Transform each prompt vector.
  4. Contextualize the previous output vector with the prompt vectors.[1]
  5. Transform this contextualized previous vector into the next vector.
  6. Decode the next vector into the next word. Output this next word. Repeat this process until the next word is end-of-sequence.

In this post, we will cover the "contextualize" step in more detail. This is illustrated below — the weighted sum right before both transform steps.

Even though we've introduced contextualization as a weighted sum, we haven't actually discussed how the weights are computed. The core question then becomes: How do we compute the weights in our weighted sum?

1. How to add context

One natural way to add context is to look for pre-determined patterns.

For example, if we see a preposition like "under," the words immediately after can add context for the meaning of "under". Here are a few example patterns: "under the bridge," "under the pseudonym," and so on.

The list could continue for a long while. At a high level, any time we see the word "under", we look for these patterns. If any pattern shows up, we incorporate that context into a new, contextualized version of "under".

Here's what the process would look like in more detail. Let's use the first pattern, "under the bridge," as an example.

  1. Take the first 3 words of our prompt. Are these 3 words equal to "under the bridge"? If so, add "the" and "bridge" to the word "under"; we've contextualized the word "under" successfully. If not, the pattern has not detected any useful context; just return the original word unperturbed.
  2. Take the next 3 words of our prompt. Are these 3 words equal to "under the bridge"? If so, add "the" and "bridge" to the word "under". If not, return the original word.
  3. Take the following 3 words of our prompt. Are these 3 words equal to "under the bridge"? If so, add "the" and "bridge" to the word "under". If not, return the original word.
  4. Continue this until we reach the end of our sequence.

At the end of this process, every instance of "under" with the context "under the bridge" has been successfully contextualized.

Now, let's make this process more concrete.

  1. Collect sequence. Take the first 3 words of our prompt, $w_1, w_2, w_3$. Take the 3 words in our pattern, $p_1, p_2, p_3$.
  2. Compute similarity. For each word in our prompt, see if it matches the corresponding word in our pattern. Compute $\langle w_1, p_1\rangle, \langle w_2, p_2 \rangle, \langle w_3, p_3 \rangle$.
  3. Weighted sum. If it matches, add that word to our output word, $\tilde{w_1} = \langle w_1, p_1 \rangle w_1 + \langle w_2, p_2 \rangle w_2 + \langle w_3, p_3 \rangle w_3$.
  4. Run on all sequences. We repeat for every 3 words in our prompt, obtaining a set of contextualized words $\tilde{w_1}, \tilde{w_2} …$

Note that our new words contain context only if the context is similar to the pattern "under the bridge".[5] The key idea is to test for similarity between the prompt and predetermined patterns.
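
Here's a minimal NumPy sketch of the four steps above (the function name and array shapes are my own). The weights are raw dot products, with none of the thresholding discussed in footnote [5], and the window slides one word at a time:

```python
import numpy as np

def contextualize_with_pattern(words, pattern):
    """Slide a fixed 3-word pattern over the prompt; where the prompt
    resembles the pattern, mix the window's words into the target word.

    words:   (n, d) array, one row per prompt-word embedding
    pattern: (3, d) array, one row per pattern-word embedding
    """
    out = words.copy()
    for i in range(len(words) - 2):              # every 3-word window
        window = words[i:i + 3]                  # w_1, w_2, w_3
        scores = (window * pattern).sum(axis=1)  # <w_1, p_1>, <w_2, p_2>, <w_3, p_3>
        out[i] = scores @ window                 # weighted sum of the window
    return out
```

Note that, exactly as in step 3's formula, a poorly-matching window still contributes a small weighted sum rather than returning the original word untouched; footnote [5] sketches a thresholded variant that behaves more like a true detector.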

This approach works but is very computationally expensive, for two reasons:[2]

  1. Many possible patterns. Our technique of manually constructing patterns is laborious. Even if we automagically learned these patterns from our data, we would need many, many patterns to represent all possible ways to add context to a word.[3] This is exceptionally expensive to run: for every single pattern, we run the above process of searching the entire prompt for that pattern.
  2. Patterns have fixed size. Say we want to add context to a question mark. Usually, the first word of the sentence tells us the kind of question: who, when, what, why, how. However, patterns have fixed size; there is no way to design a pattern that can grab the first word and the last word of a sentence of any length. As a workaround, we can simply add one pattern for every possible sentence length, aggravating the existing problem of too many patterns.[4]

The crux of the issue is that our patterns are independent of the input. Regardless of the input, we will always check for the same set of patterns, even when a significant chunk of them obviously doesn't apply.

In short, fixed patterns can't capture variable-length context, and covering every case requires an impractically large number of patterns.

2. Add context more efficiently

To handle the above challenges, we make two fixes. For our first fix, we alter our patterns to be input-dependent. One simple way to do this is to replace the fixed pattern with the input target word. In other words, instead of checking similarity with a fixed pattern, check similarity with the target word.

  1. Collect sequence. Take the first 3 words of our prompt, $w_1, w_2, w_3$.
  2. Compute similarity. For each word in our prompt, see if it matches the first word. Compute $\langle w_1, w_1\rangle, \langle w_2, w_1 \rangle, \langle w_3, w_1 \rangle$.
  3. Weighted sum. If it matches, add that word to our output word, $\tilde{w_1} = \langle w_1, w_1 \rangle w_1 + \langle w_2, w_1 \rangle w_2 + \langle w_3, w_1 \rangle w_3$.
  4. Run on all inputs. We repeat for every 3 words in our prompt, obtaining a set of contextualized words $\tilde{w_1}, \tilde{w_2} …$

This version of contextualization certainly seems silly. Now, we contextualize each word by taking a weighted sum of all the similar words. We'll make this less strange in the next step. For now, focus on the fact that our patterns are now input-dependent.
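
As a quick sketch of this step, the only change from the pattern-matcher above is that the fixed pattern is replaced by the window's own first word (function name again mine):

```python
def contextualize_input_dependent(words):
    """As before, but compare each window word against the window's
    own first word instead of a fixed, pre-determined pattern."""
    out = words.copy()
    for i in range(len(words) - 2):
        window = words[i:i + 3]      # w_1, w_2, w_3
        scores = window @ window[0]  # <w_1, w_1>, <w_2, w_1>, <w_3, w_1>
        out[i] = scores @ window     # weighted sum of the window
    return out
```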

Second, we make the pattern global. Rather than look for patterns of a specific, small length, we cover interactions with all other words in the prompt at once. Our pipeline now looks like the following:

  1. Collect sequence. Take all words in our prompt, $w_1, w_2, ... w_n$.
  2. Compute similarity. For each word in our prompt, see if it matches the first word. Compute $\langle w_1, w_1\rangle, \langle w_2, w_1 \rangle, ... \langle w_n, w_1 \rangle$.
  3. Weighted sum. If it matches, add that word to our output word, $\tilde{w_1} = \langle w_1, w_1 \rangle w_1 + \langle w_2, w_1 \rangle w_2 + \cdots + \langle w_n, w_1 \rangle w_n$.
  4. Run on all inputs. We repeat for every target word in our prompt, obtaining a set of contextualized words $\tilde{w_1}, \tilde{w_2} ... \tilde{w_n}$

We've now successfully contextualized every word, using all other words, in an input-dependent way. We'll continue to evolve this contextualization process so that $\tilde{w_1}, \tilde{w_2} ... \tilde{w_n}$ can eventually represent many patterns implicitly.
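
In code, this global version collapses into a pair of matrix products over the whole prompt, assuming the same (n, d) array of word embeddings as before (function name mine):

```python
def contextualize_global(words):
    """Contextualize every word using every other word in the prompt."""
    scores = words @ words.T  # scores[i, j] = <w_j, w_i>
    return scores @ words     # row i = sum_j scores[i, j] * w_j
```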

The key idea is that we use all inputs when contextualizing each input, introducing the idea of global awareness. We also introduce the idea of input-dependent patterns.

3. Add context using importance

This seems a bit strange[6]: One word is considered "context" for another word if the two words are similar. However, similarity doesn't necessarily indicate importance, i.e., how much the two words matter for understanding each other.

For example, "The horse, yellow-colored like a banana, went bananas." In this case, "banana" and "bananas" — although similar words — are completely irrelevant for understanding each other.

Knowing this, we transform every word. In particular, after this transformation, two words are highly similar if the two words alter each other's meaning. In other words, we define a new space where similarity is importance.

We'll call this transformation $K$.

  1. Compute similarity. For each word in our prompt, see if it is important to the first word, by computing similarity in an "importance" vector space. Compute $\langle Kw_1, Kw_1\rangle, \langle Kw_2, Kw_1 \rangle, \langle Kw_3, Kw_1 \rangle$.
  2. Weighted sum. If it matches, add that word to our output word, $\tilde{w_1} = \langle Kw_1, Kw_1 \rangle w_1 + \langle Kw_2, Kw_1 \rangle w_2 + \langle Kw_3, Kw_1 \rangle w_3$.
  3. Run on all inputs. We repeat for every target word in our prompt, obtaining a set of contextualized words $\tilde{w_1}, \tilde{w_2} ... \tilde{w_n}$
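
As a sketch (function name mine), the change from the previous version is that similarities are now computed after mapping every word through $K$, while the weighted sum still uses the original words, per step 2 above:

```python
def contextualize_importance(words, K):
    """Measure similarity in a learned 'importance' space."""
    k = words @ K.T        # K w_j for every word j
    scores = k @ k.T       # scores[i, j] = <K w_j, K w_i>
    return scores @ words  # weighted sum of the *original* words
```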

The key idea is that importance and similarity are two different ideas; formally, we introduce a transformation $K$ into a vector space where similarity is importance. However, as we'll discuss next, importance is not a two-way street.

4. Importance is asymmetric

Importance between words is asymmetric.

Say we have the word "cooler". To understand what this word means, we need context, as it may refer to a container that keeps food cool, describe a low temperature, or indicate that one person is more attractive than another. However, if we write "groceries in cooler," we immediately understand what "cooler" means. In this sense, the word "groceries" is very important for understanding "cooler".

The opposite is not true, however. The word "groceries" is meaningful on its own, and the word "cooler" doesn't change its meaning. Knowing that, the word "cooler" isn't very important to the word "groceries".

Knowing this, we modify our pipeline once more. Instead of using one vector space to represent importance generally, we define two vector spaces to represent asymmetric importance: one vector space for the target word and another for the context word. If word A is important for understanding word B, the context embedding of A will be highly similar to the target embedding of B.

We'll name the two matrices that transform a word into these spaces: $Q$ transforms the target or "query" word, and $K$ transforms the context or "key" word.

  1. Compute similarity. For each word in our prompt, see if it is important for the first word, by computing similarity in an "importance" vector space. Compute $\langle Kw_1, Qw_1\rangle, \langle Kw_2, Qw_1 \rangle, \langle Kw_3, Qw_1 \rangle$.
  2. Weighted sum. If it matches, add that word to our output word, $\tilde{w_1} = \langle Kw_1, Qw_1 \rangle w_1 + \langle Kw_2, Qw_1 \rangle w_2 + \langle Kw_3, Qw_1 \rangle w_3$.
  3. Run on all inputs. We repeat for every target word in our prompt, obtaining a set of contextualized words $\tilde{w_1}, \tilde{w_2} ... \tilde{w_n}$

The key idea is to introduce two vector spaces, one $Q$ or "query" and the other $K$ or "key"; similarity tells us how important a key is for understanding a query.
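
Here's a sketch of the asymmetric version (function name mine); the only change is that the target side and the context side are now mapped through different learned matrices:

```python
def contextualize_qk(words, Q, K):
    """Asymmetric importance: targets become queries, context becomes keys."""
    q = words @ Q.T        # Q w_i: each word as a target ("query")
    k = words @ K.T        # K w_j: each word as context (a "key")
    scores = q @ k.T       # scores[i, j] = <K w_j, Q w_i>
    return scores @ words  # weighted sum of the original words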

5. Generalizing self-attention

The above allows us to contextualize words in our prompt with other words in our prompt.

However, say one input is the prompt for the model to respond to. Say another input is context for the prompt. Yet another dictates the response length and tone. In short, we may have multiple inputs that contribute to the output differently. In fact, we may have inputs from different modalities entirely — inputs that encode images or audio, for example.

To support different input types, we add a new transformation $V$, or the "value" transformation. This matrix transforms all inputs into a shared space, so that inputs can be summed together meaningfully.

  1. Compute similarity. For each word in our prompt, see if it is important for the first word, by computing similarity in an "importance" vector space. Compute $\langle Kw_1, Qw_1\rangle, \langle Kw_2, Qw_1 \rangle, \langle Kw_3, Qw_1 \rangle$.
  2. Weighted sum. If it matches, transform that word in a shared vector space, then add the transformed words together, $\tilde{w_1} = \langle Kw_1, Qw_1 \rangle Vw_1 + \langle Kw_2, Qw_1 \rangle Vw_2 + \langle Kw_3, Qw_1 \rangle Vw_3$.
  3. Run on all inputs. We repeat for every target word in our prompt, obtaining a set of contextualized words $\tilde{w_1}, \tilde{w_2} ... \tilde{w_n}$

This is the final form for our self-attention module — the first contextualization step in our model, which adds context to the inputted prompt.

The beauty of self-attention is that queries can represent any input. Keys and values can separately represent any other input. These two inputs are completely decoupled, which we'll leverage in the next section.
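
Putting the pieces together, here is a sketch of the final form (function names mine). Because queries are decoupled from keys and values, the general version takes two inputs; self-attention is the special case where both are the prompt. This follows the post's formulas exactly; the transformer in "Attention is all you need" additionally scales the scores and normalizes them with a softmax, which we've omitted here.

```python
def attention(targets, context, Q, K, V):
    """Contextualize each row of `targets` using the rows of `context`."""
    q = targets @ Q.T  # queries: what each target is looking for
    k = context @ K.T  # keys: what each context item offers
    v = context @ V.T  # values: context mapped into a shared output space
    scores = q @ k.T   # scores[i, j] = <K c_j, Q t_i>
    return scores @ v  # weighted sum of the transformed context

def self_attention(words, Q, K, V):
    """The prompt attends to itself."""
    return attention(words, words, Q, K, V)
```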

6. Using self-attention

Here we re-include the diagram from the very beginning of this article, which summarizes next-word prediction.

Notice there's a second weighted sum in orange. This is a part of the decoder. We now need to explain how that orange weighted sum is computed. In short, that weighted sum needs to contextualize the previous words (gray, blue) using the prompt (orange).

We can now use the generalized version of self-attention above to contextualize previous words with the prompt. Before, queries, keys, and values were all various embeddings of the prompt. Now, queries are embeddings of previously-generated output words; keys and values are embeddings of the prompt, providing context.
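
Reusing the attention sketch from section 5, both weighted sums become one function called with different inputs. The variable names and the separate decoder weights below are illustrative:

```python
# Encoder side: the prompt attends to itself.
prompt_ctx = self_attention(prompt_embeddings, Q, K, V)

# Decoder side: previously-generated outputs supply the queries;
# the prompt supplies the keys and values.
output_ctx = attention(output_embeddings, prompt_embeddings, Q_dec, K_dec, V_dec)
```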

This concludes our explanation of self-attention. We've covered both contextualization steps in the model, the most critical part of the transformer.

Summary

We've now completed the self-attention module and explained the core parts of a transformer. To summarize:

  1. We used a Q matrix to transform the target or "query" word. We used a K matrix to transform the context or "key" word.
  2. In self-attention, we compute similarity between queries and keys. Each of these dot products tells us: How important is the key for understanding the query?
  3. In the decoder, queries are previous outputs; keys and values are the prompt.

At this point, the original transformer paper "Attention is all you need" by Vaswani et al. should be readable, as we've covered the core components.

For a final, practical introduction to the transformer, check out the final part of this 3-part series in Practical Introduction to Large Language Models.





  1. If we're decoding the first word, the "previous output vector" is a special vector that represents start-of-sequence.

  2. In effect, these are the same problems that filters face in a convolution. However, note that we perform the computation differently in this example.

  3. One way around the many-patterns problem is to define some "low-level patterns". Earlier patterns will capture commonly-occurring grammars, like "under the" or "over the". Later patterns can recognize longer patterns, such as "under the bridge," "under the pseudonym", and others. In effect, we're emulating convolutions in computer vision. 

  4. Our goal is to capture long-range patterns, across many words or even sentences. To do this, as with convolutions, we have two options: design one really big filter and incur expensive computational costs, or cascade many small filters, so that many small filters working together can eventually represent one really big filter. Both are computationally inefficient ways of representing long-term dependencies between distant parts of text.

  5. Granted, the explanation does not match the formulation exactly. The formulation should first detect the pattern, using a similarity score $s_1 = \langle p_1, w_1\rangle + \langle p_2, w_2 \rangle + \langle p_3, w_3 \rangle$. Then, threshold or use that similarity score to add context, such as $\tilde{w_1} = \mathbb{1}_{s_1 > 0.5} \frac{w_1 + w_2 + w_3}{3} + \mathbb{1}_{s_1 \leq 0.5} w_1$. I made a super rough approximation that instead detects whether any individual word in the prompt coincides with the corresponding word in the pattern; if so, that word is added as context. In this way, it's a pretty strange pattern-detector, but it's probably the best we can get by interpolating between pattern detectors and attention mechanisms.

  6. This idea isn't without precedent. It is the one-dimensional, language analogue of non-local means in computer vision: for every target pixel, take the average of all other pixels, weighted by each other pixel's similarity to the target pixel. We will use a similar idea: for every target word, take the average of all other words, weighted by each other word's similarity to the target word.