from Guide to Machine Learning on Apr 16, 2023

# Illustrated Intuition for Self-Attention

Large Language Models transform input words into output words, and to do this, they need to incorporate context across all inputs. The module that adds this context is called self-attention — the focus of this post.

Large Language Models rely on a fundamental building block called transformers, which we covered in Illustrated Intuition for Transformers. In this post, we dive deeper into how the transformer operates, illustrating one of its core functions — namely, to add context to its inputs via a module called self-attention.

Previously, we built the transformer pipeline from the ground up, arriving at the following 6 steps:

1. Encode each word in our prompt into a vector.
2. Contextualize each prompt vector.
3. Transform each prompt vector.
4. Contextualize the previous output vector with the prompt vectors.
5. Transform this contextualized previous vector into the next vector.
6. Decode the next vector into the next word. Output this next word. Repeat this process until the next word is end-of-sequence.

In this post, we will cover the "contextualize" step in more detail. This is illustrated below — the weighted sum right before both transform steps.

Even though we've introduced contextualization as a weighted sum, we haven't actually discussed how the weights are computed. The core question then becomes: How do we compute the weights in our weighted sum?

## 1. How to add context

One natural way to add context is to look for pre-determined patterns.

For example, if we see a preposition like "under," the words immediately after can add context for the meaning of "under". Here are a few example patterns:

• "under the bridge"
• "underdog"
• "underwhelmed"
• "under a different name"

The list could continue for a long while. At a high level, anytime we see the word "under", we look for these patterns. If any pattern shows up, we incorporate that context into a new, contextualized version of "under".

Here's what the process would look like in more detail. Let's use the first pattern, "under the bridge" as an example.

1. Take the first 3 words of our prompt. Are these 3 words equal to "under the bridge"? If so, add "the" and "bridge" to the word "under"; we've contextualized the word "under" successfully. If not, the pattern has not detected any useful context; just return the original word unperturbed.
2. Take the next 3 words of our prompt. Are these 3 words equal to "under the bridge"? If so, add "the" and "bridge" to the word "under". If not, return the original word.
3. Take the following 3 words of our prompt. Are these 3 words equal to "under the bridge"? If so, add "the" and "bridge" to the word "under". If not, return the original word.
4. Continue this until we reach the end of our sequence.

At the end of this process, every instance of "under" with the context "under the bridge" has been successfully contextualized.

Now, let's make this process more concrete.

1. Collect sequence. Take the first 3 words of our prompt, $w_1, w_2, w_3$. Take the 3 words in our pattern, $p_1, p_2, p_3$.
2. Compute similarity. For each word in our prompt, see if it matches the corresponding word in our pattern. Compute $\langle w_1, p_1\rangle, \langle w_2, p_2 \rangle, \langle w_3, p_3 \rangle$.
3. Weighted sum. If it matches, add that word to our output word, $\tilde{w_1} = \langle w_1, p_1 \rangle w_1 + \langle w_2, p_2 \rangle w_2 + \langle w_3, p_3 \rangle w_3$.
4. Run on all sequences. We repeat for every 3 words in our prompt, obtaining a set of contextualized words $\tilde{w_1}, \tilde{w_2}, \ldots$
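The steps above can be sketched in a few lines of NumPy. The 2-dimensional word vectors below are hypothetical stand-ins; a real model would use learned embeddings:

```python
import numpy as np

# Hypothetical 2-d word vectors; a real model would learn these embeddings.
vocab = {
    "under":  np.array([1.0, 0.0]),
    "the":    np.array([0.0, 1.0]),
    "bridge": np.array([1.0, 1.0]),
}

prompt  = [vocab["under"], vocab["the"], vocab["bridge"]]
pattern = [vocab["under"], vocab["the"], vocab["bridge"]]

# Step 2: similarity of each prompt word with the corresponding pattern word.
weights = [float(np.dot(w, p)) for w, p in zip(prompt, pattern)]

# Step 3: weighted sum yields the contextualized first word.
w1_tilde = sum(a * w for a, w in zip(weights, prompt))
```

Because this prompt exactly matches the pattern, every weight is positive and the contextualized vector absorbs "the" and "bridge"; for an unrelated prompt, the weights, and therefore the added context, would shrink toward zero.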

Note our new words contain context only if the context is similar to the pattern "under the bridge". The key idea is to test for similarity between the prompt and predetermined patterns.

This approach works but is very computationally expensive, for two reasons:

1. Many possible patterns. Our technique of manually constructing patterns is laborious. Even if we automagically learned these patterns from our data, we would need many, many patterns to represent all possible ways to add context to a word. This is exceptionally expensive to run: for every single pattern, we must run the above process of searching the entire prompt for that pattern.
2. Patterns have fixed size. Say we want to add context to a question mark. Usually, the first word of the sentence tells us the kind of question — who, when, what, why, how. However, patterns have fixed size; there is no way to design a pattern that can grab the first word and the last word, for a sentence of any length. As a workaround, we can simply add one pattern for every possible sentence length, aggravating the already-existent problem of too many patterns.

The crux of the issue is that our patterns are independent of the input. Regardless of the input, we will always check for the same set of patterns, even if a significant chunk of them obviously doesn't apply.

In short, fixed patterns can't capture variable-length context, and representing a large number of patterns is computationally prohibitive.

## 2. Add context more efficiently

To handle the above challenges, we make two fixes. For our first fix, we alter our patterns to be input-dependent. One simple way to do this is to replace the fixed pattern with the input target word. In other words, instead of checking similarity with a fixed pattern, check similarity with the target word.

1. Collect sequence. Take the first 3 words of our prompt, $w_1, w_2, w_3$.
2. Compute similarity. For each word in our prompt, see if it matches the first word. Compute $\langle w_1, w_1\rangle, \langle w_2, w_1 \rangle, \langle w_3, w_1 \rangle$.
3. Weighted sum. If it matches, add that word to our output word, $\tilde{w_1} = \langle w_1, w_1 \rangle w_1 + \langle w_2, w_1 \rangle w_2 + \langle w_3, w_1 \rangle w_3$.
4. Run on all inputs. We repeat for every 3 words in our prompt, obtaining a set of contextualized words $\tilde{w_1}, \tilde{w_2}, \ldots$
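As a sketch (again with made-up 2-d vectors), the only change from before is that the pattern is now the window's own first word:

```python
import numpy as np

window = [np.array([1.0, 0.0]),   # w1, the target word
          np.array([0.0, 1.0]),   # w2
          np.array([1.0, 1.0])]   # w3

target = window[0]
# Step 2: similarity of every word in the window with the target word itself.
weights = [float(np.dot(w, target)) for w in window]
# Step 3: weighted sum gives the contextualized target word.
w1_tilde = sum(a * w for a, w in zip(weights, window))
```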

This version of contextualization certainly seems silly. Now, we contextualize each word by taking a weighted sum of all the similar words. We'll make this less strange in the next step. For now, focus on the fact that our patterns are now input-dependent.

Second, we make the pattern global. Rather than look for patterns of a specific, small length, we cover interactions with all other words in the prompt, at once. Our pipeline now looks like the following:

1. Collect sequence. Take all words in our prompt, $w_1, w_2, ... w_n$.
2. Compute similarity. For each word in our prompt, see if it matches the first word. Compute $\langle w_1, w_1\rangle, \langle w_2, w_1 \rangle, ... \langle w_n, w_1 \rangle$.
3. Weighted sum. If it matches, add that word to our output word, $\tilde{w_1} = \langle w_1, w_1 \rangle w_1 + \langle w_2, w_1 \rangle w_2 + \cdots + \langle w_n, w_1 \rangle w_n$.
4. Run on all inputs. We repeat for every target word in our prompt, obtaining a set of contextualized words $\tilde{w_1}, \tilde{w_2} ... \tilde{w_n}$
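In matrix form, this global version is very compact: stacking the word vectors as rows of a matrix $W$, the similarity scores are $WW^\top$ and the contextualized words are $(WW^\top)W$. A sketch with toy vectors:

```python
import numpy as np

# Rows are toy word vectors for a 4-word prompt (hypothetical values).
W = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [1.0, 1.0],
              [0.5, 0.5]])

scores = W @ W.T      # entry (i, j) is <w_j, w_i>
W_tilde = scores @ W  # row i is the contextualized word for position i
```

Note that `scores` is symmetric here: word $i$ weights word $j$ exactly as much as $j$ weights $i$.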

We've now successfully contextualized every word, using all other words, in an input-dependent way. We'll continue to evolve this contextualization process so that $\tilde{w_1}, \tilde{w_2} ... \tilde{w_n}$ can eventually represent many patterns implicitly.

The key idea is that we use all inputs when contextualizing each input, introducing the idea of a global awareness. We also introduce the idea of input-dependent patterns.

## 3. Add context using importance

This seems a bit strange: One word is considered "context" for another word if the two words are similar. However, similarity doesn't necessarily indicate importance, that is, how much the two words matter for understanding each other.

For example, "The horse, yellow-colored like a banana, went bananas." In this case, "banana" and "bananas" — although similar words — are completely irrelevant for understanding each other.

Knowing this, we transform every word. In particular, after this transformation, two words are highly similar if the two words alter each other's meaning. In other words, we define a new space where similarity is importance.

We'll call this transformation $K$.

1. Compute similarity. For each word in our prompt, see if it is important to the first word, by computing similarity in an "importance" vector space. Compute $\langle Kw_1, Kw_1\rangle, \langle Kw_2, Kw_1 \rangle, \langle Kw_3, Kw_1 \rangle$.
2. Weighted sum. If it matches, add that word to our output word, $\tilde{w_1} = \langle Kw_1, Kw_1 \rangle w_1 + \langle Kw_2, Kw_1 \rangle w_2 + \langle Kw_3, Kw_1 \rangle w_3$.
3. Run on all inputs. We repeat for every target word in our prompt, obtaining a set of contextualized words $\tilde{w_1}, \tilde{w_2} ... \tilde{w_n}$
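Concretely, with a stand-in matrix $K$ (random here; learned in practice), similarity is now measured between $Kw_i$ and $Kw_j$, while the weighted sum still mixes the original word vectors:

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(4, 2))   # 4 toy word vectors
K = rng.normal(size=(2, 2))   # "importance" transform; learned in practice

keys = W @ K.T                # row i is K w_i
scores = keys @ keys.T        # entry (i, j) is <K w_j, K w_i>
W_tilde = scores @ W          # weighted sums of the *original* words
```

These scores are still symmetric, since $\langle Kw_i, Kw_j \rangle = \langle Kw_j, Kw_i \rangle$.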

The key idea is that importance and similarity are two different ideas; formally, we introduce a vector space $K$, where similarity is importance. However, as we'll discuss next, importance is not a two-way street.

## 4. Importance is asymmetric

Importance between words is asymmetric.

Say we have the word "cooler". To understand what this word means, we need context, as it may refer to a container to keep food cool, may describe a low temperature, or indicate one person is more attractive than another. However, if we write, "groceries in cooler," we immediately understand what "cooler" means. In this sense, the word "groceries" is very important for understanding "cooler".

The opposite is not true, however. The word "groceries" is meaningful on its own, and the word "cooler" doesn't change its meaning. Knowing that, the word "cooler" isn't very important to the word "groceries".

Knowing this, we modify our pipeline once more. Instead of using one vector space to represent importance generally, we define two vector spaces to represent asymmetric importance: one vector space for the target word and another vector space for the context word. If word A is important for understanding word B, context A will be highly similar to target B.

We'll name the two matrices that transform a word into these spaces: $Q$, for the "query" (target) word, and $K$, for the "key" (context) word.

1. Compute similarity. For each word in our prompt, see if it is important for the first word, by computing similarity in an "importance" vector space. Compute $\langle Kw_1, Qw_1\rangle, \langle Kw_2, Qw_1 \rangle, \langle Kw_3, Qw_1 \rangle$.
2. Weighted sum. If it matches, add that word to our output word, $\tilde{w_1} = \langle Kw_1, Qw_1 \rangle w_1 + \langle Kw_2, Qw_1 \rangle w_2 + \langle Kw_3, Qw_1 \rangle w_3$.
3. Run on all inputs. We repeat for every target word in our prompt, obtaining a set of contextualized words $\tilde{w_1}, \tilde{w_2} ... \tilde{w_n}$
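With separate stand-in matrices $Q$ and $K$ (again random here, learned in practice), the score matrix is no longer symmetric: word $j$ can matter a lot to word $i$ without the reverse being true. This is essentially the unnormalized core of self-attention:

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(4, 2))   # 4 toy word vectors
Q = rng.normal(size=(2, 2))   # maps each word into "query" (target) space
K = rng.normal(size=(2, 2))   # maps each word into "key" (context) space

queries = W @ Q.T             # row i is Q w_i
keys    = W @ K.T             # row j is K w_j
scores  = queries @ keys.T    # entry (i, j) is <K w_j, Q w_i>
W_tilde = scores @ W          # contextualized words
```

Production attention layers additionally normalize each row of the scores (with a softmax) and transform the summed words with a third "value" matrix, but the query/key asymmetry shown here is the heart of the mechanism.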

The key idea is to introduce two vector spaces, one $Q$ or "query" and the other $K$ or "key"; similarity tells us how important a key is for understanding a query.

## 5. Generalizing self-attention

The above allows us to contextualize words in our prompt with other words in our prompt.

However, say one input is the prompt for the model to respond to. Say another input is context for the prompt. Yet another dictates the response length and tone. In short, we may have multiple inputs that contribute to the output differently. In fact, we may have inputs from different modalities entirely — inputs that encode images or audio, for example.