from Guide to Machine Learning on May 5, 2024

How Flash Attention works

In previous posts, we saw how jointly tiling the two matrix multiplies in self-attention can cut down on the number of DRAM reads and writes. To explain this, we covered matrix multiplication from the ground up in a three-part series on efficient matrix multiply kernels.

However, there was a crucial detail that I skipped over — namely, how to "tile softmax".

The Problem with Softmax

To start, let me explain the issue by revisiting the equation for softmax.

$$\textrm{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}}$$

Notice the denominator in softmax: this sum requires every element of $x$. By extension, if we compute softmax along a row of a matrix, we need to read every value in that row. Unfortunately for us, Flash Attention never has access to an entire row of values at once. To see why, here is the equation for self-attention again.

$$\textrm{softmax}(\frac{QK^T}{\sqrt{d_k}})V$$

In Flash Attention, both matrix multiplies, $QK^T$ and $\textrm{softmax}(…)V$, are jointly tiled. This means that at any given moment we only have a single tile of $QK^T$, in other words, only parts of rows. This raises the question: how do we properly compute softmax, which requires entire rows, when we only have access to parts of rows?

Streaming Mean

Our problem above is what's called a streaming problem. We're only streamed blocks of the input at a time, and we need to eventually return a final result that is based on all the inputs we've seen.

Let's look at another streaming algorithm to get the gist. Say we have an infinite stream of numbers, and we want to compute the mean of all the numbers we've seen so far, using $O(1)$ storage. There's a trivial solution: simply store

  1. A running total of all the values you've seen so far
  2. A running count of how many values you've seen so far

To return the mean at any time, simply return the total over the count.
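
As a concrete sketch (the class and method names here are just illustrative), the running-total version fits in a few lines of Python:

```python
# Illustrative sketch: streaming mean via a running total and count.
class StreamingMeanNaive:
    def __init__(self):
        self.total = 0.0  # sum of all values seen so far
        self.count = 0    # number of values seen so far

    def update(self, x):
        self.total += x
        self.count += 1

    def mean(self):
        # The mean at any point is just total / count.
        return self.total / self.count
```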

However, let's say the values in our stream are fairly large, so large, in fact, that summing just a few thousand of them would overflow even a float64. So let's try a different approach. This time, store

  1. A running mean over all inputs you've seen so far
  2. A running count of how many inputs you've seen so far

Say we now get a new value from the stream. We have the old mean $\mu_{k-1} = \frac{\sum_{i=1}^{k-1}x_i}{k-1}$, the old count $k-1$, and the new value $x_k$. We can now represent the new mean in terms of these three values.

$$\begin{align} \mu_k &= \frac{\sum_{i=1}^{k}x_i}{k} \\ &= \frac{\sum_{i=1}^{k-1}x_i}{k} + \frac{x_k}{k} \\ &= \underbrace{\frac{\sum_{i=1}^{k-1}x_i}{k-1}}_{\mu_{k-1}}\frac{k-1}{k} + x_k\frac{1}{k}\\ &= \mu_{k-1} \frac{k-1}{k} + x_k \frac{1}{k} \end{align}$$

In effect, we compute a weighted average between the old mean and the new value. This works because the first fraction $\frac{k-1}{k}$ acts as a correction factor — it gets rid of the old denominator $k-1$ and applies a new denominator $k$. And ta-da! We now have an algorithm for computing the mean of an infinite stream of values, with $O(1)$ memory no less.
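
Here's the same sketch rewritten with the weighted-average update; again, the names are mine, not from any particular library:

```python
# Illustrative sketch: streaming mean via the correction factor, no running total.
class StreamingMean:
    def __init__(self):
        self.mean = 0.0  # mean of all values seen so far
        self.count = 0   # number of values seen so far

    def update(self, x):
        self.count += 1
        k = self.count
        # Weighted average: scale the old mean by (k-1)/k and the new value by 1/k.
        self.mean = self.mean * (k - 1) / k + x / k
```

Feeding in 2, 4, and 6 one at a time leaves `mean` at 3.0 after two updates and 4.0 after three, matching the batch computation, while only ever storing two numbers.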

Numerically Stable Softmax

We've seen how the streaming mean works; let's now apply this methodology to Flash Attention. First recall how softmax is used in self-attention: A row of softmax'ed values is dotted with a column of $V$. Our goal is to ensure the final, accumulated dot product is correct.

Before figuring out how to stream softmax, let's first revisit the formula for softmax, as we presented it earlier.

$$\textrm{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}}$$

However, large values $x_i$ in the exponent can cause overflow errors fairly easily. To avoid this, we add a constant $C$ to every exponent:

$$\textrm{softmax'}(x)_i = \frac{e^{x_i + C}}{\sum_j e^{x_j + C}}$$

Using the identity $e^{a+b} = e^ae^b$ and a little algebra, we can show that this new softmax' is identical to the original softmax from above.

$$\begin{align} \textrm{softmax'}(x)_i &= \frac{e^{x_i + C}}{\sum_j e^{x_j + C}}\\ &= \frac{e^{x_i}e^C}{\sum_j e^{x_j}e^C}\\ &= \frac{e^{x_i}e^C}{e^C\sum_j e^{x_j}}\\ &= \frac{e^{x_i}}{\sum_j e^{x_j}}\\ &= \textrm{softmax}(x)_i \end{align}$$

So long as the constant $C$ doesn't vary with $i$, we can successfully pull the term $e^C$ out of the summation. To this end, we define $C = -\textrm{max}_j x_j = -m$. Our new, numerically stable variant of softmax is then

$$\textrm{softmax'}(x)_i = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}}$$
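
As a quick sanity check on this formula, here's a minimal NumPy sketch of the stable variant for the easy case where the entire row is available at once (NumPy and the function name are my choices, not anything from the Flash Attention code):

```python
import numpy as np

def softmax_stable(x):
    """Numerically stable softmax over a 1-D array."""
    m = np.max(x)        # m = max_j x_j
    e = np.exp(x - m)    # e^{x_i - m}; every exponent is <= 0, so no overflow
    return e / np.sum(e)
```

For example, `softmax_stable(np.array([1000.0, 1001.0]))` returns finite probabilities, where the naive formula would overflow to infinity and produce NaNs.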

Streaming Softmax Denominator

For our first step, let's figure out how to write a streaming algorithm for just the maximum and summation in the denominator. In this case, we can store

  1. The maximum of all values so far, $m_k$
  2. The sum of all exponentiated terms so far, $\ell_k = \sum_{j=1}^k e^{x_j - m_k}$

When we receive the next value in the stream, $x_k$, we have the previous maximum $m_{k-1}$ and the previous summation $\ell_{k-1}$. Fortunately for us, updating the maximum is straightforward: take the maximum of the new value and the old maximum.

$$m_k = \textrm{max}(m_{k-1}, x_k)$$

Updating the summation is slightly trickier, but we can use the same idea from the streaming mean algorithm above. Just like with the streaming mean, there's an old factor we need to erase, $e^{-m_{k-1}}$, and a new one we need to apply, $e^{-m_k}$. Using this principle, we can represent the new summation in terms of the previous summation $\ell_{k-1}$, the previous maximum $m_{k-1}$, and the new maximum $m_k$.

$$\begin{align} \ell_k &= \sum_{j=1}^k e^{x_j - m_k} \\ &= \sum_{j=1}^{k-1} e^{x_j - m_k} + e^{x_k - m_k} \\ &= \sum_{j=1}^{k-1} e^{x_j - m_k + (m_{k-1} - m_{k-1})} + e^{x_k - m_k} \\ &= e^{m_{k-1} - m_k} \underbrace{\sum_{j=1}^{k-1} e^{x_j - m_{k-1}}}_{\ell_{k-1}} + e^{x_k - m_k} \\ &= e^{m_{k-1} - m_k} \ell_{k-1} + e^{x_k - m_k} \end{align}$$

We've got a streaming softmax denominator! At every step $k$, it correctly accounts for all of the values $x_1 \ldots x_k$ seen so far, even though we never revisit the earlier values: the update only needs $m_{k-1}$, $\ell_{k-1}$, and the new value $x_k$.
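
In code, one step of this update might look like the following sketch (pure Python on scalars, with hypothetical names; the real kernel does this per row of a tile):

```python
import math

def update_denominator(m_prev, l_prev, x_k):
    """One streaming update of the softmax max and denominator.

    m_prev: max over x_1 ... x_{k-1}
    l_prev: sum of e^{x_j - m_prev} for j = 1 ... k-1
    x_k:    the newly streamed value
    """
    m_k = max(m_prev, x_k)
    # Correction factor e^{m_prev - m_k} rescales the old sum to the new max,
    # then we add the new term e^{x_k - m_k}.
    l_k = math.exp(m_prev - m_k) * l_prev + math.exp(x_k - m_k)
    return m_k, l_k
```

Starting from `m = float("-inf")` and `l = 0.0` and folding in each value of a row reproduces the full-row statistics $m$ and $\ell = \sum_j e^{x_j - m}$.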

Streaming Softmax Dot Product

Now that we're able to stream the denominator, we can correctly stream the accumulated dot product as well. Let's say we have a partially computed dot product that accounts for the first $k$ terms. Here's what that looks like.

$$p_k = \sum_{i=1}^k \underbrace{\frac{e^{x_i - m_k}}{\ell_k}}_{\textrm{softmax'}(x)_i} y_i$$

Notice that the terms $\ell_k$ and $m_k$ are both independent of $i$, so both can be pulled out of the summation.

$$p_k = \frac{e^{-m_k}}{\ell_k} \sum_{i=1}^k e^{x_i} y_i$$

Now we have an interesting observation: the dot product so far is simply off by a constant factor, independent of $i$! This setup is just like all of our previous streaming algorithms. When calculating the $k$-th term from the $(k-1)$-th term, simply remove the previous, outdated denominator and apply a new, updated denominator.

Since we've done this a few times, I won't walk through all of the algebra. Instead, I'll just write the dot product over $k$ terms, $p_k$, as a function of the previous dot product over $k-1$ terms, $p_{k-1}$.

$$p_k = \frac{e^{-m_k}}{\ell_k} \frac{\ell_{k-1}}{e^{-m_{k-1}}} p_{k-1} + \frac{e^{-m_k}}{\ell_k} e^{x_k} y_k$$
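
If you'd like to check the step yourself, the only substitution needed is $\sum_{i=1}^{k-1} e^{x_i} y_i = \frac{\ell_{k-1}}{e^{-m_{k-1}}} p_{k-1}$, which follows directly from the expression for $p_{k-1}$:

$$\begin{align} p_k &= \frac{e^{-m_k}}{\ell_k} \left( \sum_{i=1}^{k-1} e^{x_i} y_i + e^{x_k} y_k \right) \\ &= \frac{e^{-m_k}}{\ell_k} \underbrace{\frac{\ell_{k-1}}{e^{-m_{k-1}}} p_{k-1}}_{\sum_{i=1}^{k-1} e^{x_i} y_i} + \frac{e^{-m_k}}{\ell_k} e^{x_k} y_k \end{align}$$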

Now, we can use this expression to continuously update our accumulated dot product. At any point in time, including the very end when we've finished streaming all tiles, we will now have a correctly computed dot product with softmax outputs. Ta-da!
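
To make the whole thing concrete, here's a small, purely illustrative Python sketch that streams one $(x_k, y_k)$ pair at a time; the real Flash Attention kernel works on tiles and keeps these statistics per row. Note that I've grouped the correction factor as $e^{m_{k-1} - m_k}$, which is the same quantity as $\frac{e^{-m_k}}{e^{-m_{k-1}}}$ above but avoids ever computing $e^{x_k}$ or $e^{-m_k}$ on their own:

```python
import math

def streaming_softmax_dot(xs, ys):
    """Compute sum_i softmax(x)_i * y_i in a single pass over (x_i, y_i) pairs."""
    m = float("-inf")  # running max of x_1 ... x_k
    l = 0.0            # running denominator: sum of e^{x_j - m}
    p = 0.0            # running softmax-weighted dot product
    for x_k, y_k in zip(xs, ys):
        m_new = max(m, x_k)
        corr = math.exp(m - m_new)  # correction factor e^{m_{k-1} - m_k}
        l_new = corr * l + math.exp(x_k - m_new)
        # Undo the old denominator, apply the max correction, add the new term,
        # then re-apply the new denominator.
        p = (p * l * corr + math.exp(x_k - m_new) * y_k) / l_new
        m, l = m_new, l_new
    return p
```

On a full row, this matches computing the stable softmax over the whole row and then dotting it with $y$ all at once, up to floating point error.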

Takeaways

With this, we've now covered the major components of Flash Attention: Jointly tile two back-to-back matrix multiplies to cut down on memory accesses, and stream the softmax dot product. Generally speaking, we simply applied a "correction factor" to our intermediate result, ensuring at every point along the stream that we had a correct statistic.

Streaming algorithms aren't just central to Flash Attention, however. They're also critical for a different latency optimization for Large Language Models: speculative decoding. We cover that next in How speculative decoding works.
