from Guide to Machine Learning on May 07, 2023
When to tile two matrix multiplies
Matrix multiplication is an extremely well-studied and well-optimized operation. However, what if you have two matrix multiplies, instead of just one?
What's less obvious is how to optimize successive matrix multiplies. Flash Attention proposes a mechanism for accomplishing exactly this in the two-matmul case: jointly tiling a pair of matrix multiplies that are applied directly one after another. Using this, Flash Attention is able to speed up language model training by over 3x! This begs the question: How do we jointly optimize two matrix multiplies? More generally, when is jointly optimizing multiple matrix multiplies beneficial, and when is it not?
How to optimize two matrix multiplies
Previously, we showed how to make a single matrix multiplication faster and more memory-efficient, using tiling. Effectively, we accomplished this by breaking down the matrix multiply into smaller, independent subproblems.
In this post, we'll extend this to two matrix multiplies, making the pair faster and more memory-efficient. In particular, we show how to jointly tile two back-to-back matrix multiplies by again breaking them down into smaller, independent subproblems. Here's what this means in more detail:
- Say we have two matrix multiplications.
- Partially compute the first matrix multiplication, to obtain a part of the intermediate output.
- Immediately use this partial intermediate output. Notice we didn't have to fully compute and store the intermediate output.
- Partially compute the second matrix multiplication, to obtain a part of our final output.
- Repeat this process for all parts of the input until we complete the final output.
In this way, we never represent the intermediate output explicitly, saving unnecessary reads and writes to memory; in turn, this saves latency. This is what we call "joint tiling"¹ of two matrix multiplies. Let's now discuss how this works in more detail.
How to tile two matrix multiplies
Say we have two matrix multiplications back-to-back, such as $y = ABC$. We could start by simply applying the tiled single matrix multiplication from before, twice. Unfortunately, this incurs quite a few fetches and writes:
- Fetch $A$ and $B$. Compute. Write the results to $\tilde{y}$.
- Fetch $\tilde{y}$ and $C$. Compute. Write the results to $y$.
Notice the extra reads and writes incurred for storing an intermediate $\tilde{y}$ result. Our goal is to eliminate those extra reads and writes, instead writing directly to $y$, without ever explicitly storing $\tilde{y}$.
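To make the data flow concrete, here's a minimal NumPy sketch of the fused computation (the function name, block size, and use of NumPy are illustrative choices, not how a real kernel would be written). The point is the access pattern: only a small piece of the intermediate ever exists at a time, and results are accumulated directly into $y$.

```python
import numpy as np

def fused_two_matmuls(A, B, C, block=128):
    """Compute (A @ B) @ C without materializing the full intermediate A @ B.

    Shapes follow the post: A is (m, k), B is (k, n), C is (n, k).
    """
    m, k = A.shape
    n = B.shape[1]
    y = np.zeros((m, k), dtype=A.dtype)
    for i in range(0, m, block):                        # row block of A (and of y)
        A_blk = A[i:i + block]                          # (b, k) block, stays "on chip"
        for j in range(0, n, block):                    # column block of B, row block of C
            partial = A_blk @ B[:, j:j + block]         # small piece of the intermediate
            y[i:i + block] += partial @ C[j:j + block]  # consumed immediately, never stored
    return y
```

You can sanity-check it against the naive version with `np.allclose(fused_two_matmuls(A, B, C), (A @ B) @ C)`.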
Unfortunately, these savings aren't free; we need to recompute parts of our intermediate tensors multiple times and fetch our weight matrices repeatedly, in exchange.
To understand why, consider a single value of $y = \tilde{y}C$, our final output. That value requires a row vector in our intermediate output $\tilde{y}$ and a column vector in $C$.
- We reuse this row of $\tilde{y}$ repeatedly, once for each value in a row of $y$.
- To compute a row of $\tilde{y}$, we need the entire matrix $B$.
These two facts combined mean that we increase compute and memory accesses for our weight matrices, in exchange for fewer memory accesses for intermediate outputs.
This illustrates the intuitive takeaway: Jointly tile matrix multiplies if the intermediate outputs are larger than the weight matrices. Now, let's compute exactly when this tradeoff is favorable.
When to tile two matrix multiplies
Let's start with the computational cost for both approaches — the naive one with two separate matrix multiplications and the jointly-tiled one with two concurrently tiled matrix multiplies.
Say we wish to compute $\tilde{y} = AB \in \mathbb{R}^{m \times n}$ where $A \in \mathbb{R}^{m \times k}, B \in \mathbb{R}^{k \times n}$. In this case, for every one of the $mn$ outputs, we have $k$ multiplies and $k-1$ additions. This makes $mn(2k - 1)$ FLOPs in total, which we'll approximate as $2mnk$ FLOPs.
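For a concrete (and purely illustrative) sense of scale, take $m = n = 1024$ and $k = 128$:
$$2mnk = 2 \cdot 1024 \cdot 1024 \cdot 128 \approx 2.7 \times 10^8 \text{ FLOPs}$$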
Say we add another matrix multiplication so that $y = (AB)C \in \mathbb{R}^{m \times k}$ where $C \in \mathbb{R}^{n \times k}$. The naive approach would then incur
$$\text{FLOPs}_\text{naive} = 2mnk + 2mkn = 4mnk$$
FLOPs. In the jointly-tiled approach, we recompute each row of the intermediate output $\tilde{y}$ a total of $\lceil \frac{k}{b} \rceil$ times, for block size $b$. In other words, we recompute once for every block in a row of the final output $y$. This means in the fused approach, we have
$$\text{FLOPs}_\text{tiled} = 2mnk\lceil\frac{k}{b}\rceil + 2mkn = 2mnk(\lceil \frac{k}{b} \rceil + 1)$$
FLOPs. Notice that $b \leq k$, meaning the computational cost for tiling can only increase relative to the naive approach. Jointly tiled and naive FLOPs are identical when $k = b$. In other words, if a single block spans a full row of the final output $y$, there is no computational overhead for fusing two matrix multiplies. This is ideally a requirement we meet moving forward.
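To get a feel for how quickly the recomputation penalty grows when $b < k$, here's a small Python helper that evaluates both FLOP counts from the formulas above; the sizes are arbitrary examples.

```python
import math

def flops_naive(m, n, k):
    # Two separate matmuls, 2mnk FLOPs each.
    return 2 * m * n * k + 2 * m * k * n

def flops_tiled(m, n, k, b):
    # The first matmul is recomputed ceil(k / b) times; the second runs once.
    return 2 * m * n * k * math.ceil(k / b) + 2 * m * k * n

print(flops_tiled(1024, 1024, 128, 128) / flops_naive(1024, 1024, 128))  # 1.0, since b == k
print(flops_tiled(1024, 1024, 128, 32) / flops_naive(1024, 1024, 128))   # 2.5, since ceil(k/b) == 4
```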
[Figure: an illustration of the $k=b$ case, in particular when $k = b = 1$. Notice that the weights $A$, $B$, and $C$ are far smaller than the intermediate output.]
Let's now look at the memory accesses for both approaches. Let's reuse our example from above to be consistent. According to our previous discussion on How to tile matrix multiplication, a tiled matrix multiplication with block size $b$ requires a total of $\frac{2mnk}{b}$ memory reads.
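As a quick refresher on where that count comes from (ignoring ceil rounding): each $b \times b$ output block is built up over $\frac{k}{b}$ steps, and each step reads one $b \times b$ block of each input matrix, giving
$$\underbrace{\frac{mn}{b^2}}_{\text{output blocks}} \times \underbrace{\frac{k}{b}}_{\text{steps per block}} \times \underbrace{2b^2}_{\text{reads per step}} = \frac{2mnk}{b}.$$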
Running our two matrix multiplications back-to-back naively would thus require $\frac{2mnk}{b} + \frac{2mkn}{b}$ memory reads. However, we also need to account for the memory writes. The first matrix multiply writes $mn$ outputs and the second writes $mk$ outputs. This gives us a total of
$$\text{MemAcc}_\text{naive} = \frac{2mnk}{b} + \frac{2mkn}{b} + mn + mk = 4\frac{mnk}{b} + mn + mk$$
memory accesses. In the tiled approach, we fetch each block of $A$ once, as you can see in our visualization above. That's $bk$ fetches for a single $A$ block. Per block of $A$, we then fetch column blocks of $B$ and the corresponding row blocks of $C$. That's $bk$ fetches for each block of $B$ for $\lceil \frac{n}{b}\rceil$ total blocks in $B$. The same goes for $C$.
We then write the $b \times k$ output, which takes $bk$ writes. This gives $2bk + 2bk\lceil\frac{n}{b}\rceil$ memory accesses per block of $A$. Finally, we repeat this for all $\lceil\frac{m}{b}\rceil$ blocks in $A$. This means in the tiled approach we have
$$\text{MemAcc}_\text{tiled} = \lceil\frac{m}{b}\rceil(2bk + 2bk\lceil\frac{n}{b}\rceil) \approx 2\frac{mnk}{b} + 2mk$$
memory accesses. These two expressions aren't easy to compare at a glance, so let's work out what they tell us. The question we want to answer is: When do we have fewer memory accesses in the tiled version than in the original? For simplicity, ignore the ceil rounding.
$$\begin{align} \text{MemAcc}_{\text{tiled}} &< \text{MemAcc}_{\text{naive}} \\[6pt] 2\frac{mnk}{b} + 2mk &< 4\frac{mnk}{b} + mn + mk & \text{like terms}\\[6pt] mk &< 2\frac{mnk}{b} + mn & \text{multiply by $\frac{b}{m}$}\\ kb &< 2nk + nb & \text{isolate n}\\ \frac{kb}{2k + b} &< n \end{align}$$
On its own, this expression isn't particularly illuminating. However, let's consider the interesting case above, where $b = k$. In that case, we achieve the best-case scenario for the tiled approach's computational cost. Knowing that, let's plug in $b = k$ to get
$$\begin{align}n &> \frac{k \cdot k}{2k + k} = \frac{k^2}{3k}\\[6pt] n &> \frac{k}{3}\\[6pt] k &< 3n\end{align}$$
In short, if we're multiplying $ABC$, where $A \in \mathbb{R}^{m \times k}$, $B \in \mathbb{R}^{k \times n}$, $C \in \mathbb{R}^{n \times k}$, we need to meet the following two criteria (sanity-checked numerically below):
- We need $b = k$ to not incur extra computational cost, for tile size $b$.
- This gives us the constraint that $k < 3n$ for there to be reduced memory accesses.
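As that sanity check, here's a small script that plugs arbitrary example sizes into the two memory-access formulas with $b = k$, ignoring ceil rounding as in the derivation above; the crossover lands right around $n = k/3$, as predicted.

```python
def memacc_naive(m, n, k, b):
    # Reads for two separate tiled matmuls, plus writes for the
    # m*n intermediate and the m*k final output.
    return 4 * m * n * k / b + m * n + m * k

def memacc_tiled(m, n, k, b):
    # Jointly tiled: roughly 2mnk/b reads of B and C blocks, plus
    # mk reads of A and mk writes of the final output.
    return 2 * m * n * k / b + 2 * m * k

m, k = 2048, 128
for n in (32, 43, 64, 2048):  # crossover predicted at n > k / 3 ~= 42.7
    print(n, memacc_tiled(m, n, k, b=k) < memacc_naive(m, n, k, b=k))
# Prints False for n=32, then True for n=43 and larger.
```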
Let's now use this framework to understand the latency reduction for Flash Attention.
Why jointly tiling self-attention reduces latency
Let's consider a typical self-attention module in a Large Language Model. We covered the precise formulation in Practical Introduction to Large Language Models, where we found that a typical self-attention module has the following expression:
$$\begin{align}Y_a &= \text{Attention}(X_Q, X_K, X_V) \\ &= \text{softmax}(\frac{X_QX_K^T}{\sqrt{d}})X_V\end{align}$$
where $X_{\{Q,K,V\}} \in \mathbb{R}^{s \times d_\text{head}}$ for sequence length $s$. Using our notation of $m, n, k$ above, this means that $m = n = s$ and that $k = d_\text{head}$. Rewritten using our new variables, our constraints are now that
- $b = d_\text{head}$ or block size must match $d_\text{head}$ to avoid extra computational cost.
- $d_\text{head} < 3s$ for there to be reduced memory accesses.
This is the best part: as we'll verify with some quick numbers below, the self-attention module satisfies both constraints.
- Per head, the dimensionality $d_\text{head}$ is usually only 64 to 128. This means setting the tile size $b = d_\text{head}$ would create tiles with $128^2 \approx 16,000$ values. Even if each value were FP32, that would be 64 KB total, which comfortably fits in 96 KB of shared memory. As a result, we can feasibly satisfy the first constraint, that $b = d_\text{head}$.
- To satisfy $d_\text{head} < 3s$, we would equivalently need $s > \frac{d_\text{head}}{3} \approx 43$. In other words, we would need a sequence length of at least 43 tokens. This is also trivial to satisfy, since most Large Language Models have a maximum context length of at least 2048 — which we'll saturate at train time. As a result, we can feasibly satisfy the second constraint too, that $d_\text{head} < 3s$.
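For a rough sense of the payoff, we can plug representative shapes into the memory-access formulas from above, say $m = n = s = 2048$ and $k = b = d_\text{head} = 128$ per attention head. This simplified count ignores the softmax and scaling entirely, so treat it as a ballpark:
$$\text{MemAcc}_\text{naive} \approx 5s^2 + s\,d_\text{head} \approx 2.1 \times 10^7 \qquad \text{vs.} \qquad \text{MemAcc}_\text{tiled} \approx 2s^2 + 2s\,d_\text{head} \approx 8.9 \times 10^6$$
That's roughly 2.4x fewer memory accesses per head, with no extra compute.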
These two convenient facts make jointly tiling attention matrix multiplications an obvious choice — thus, the effectiveness of Flash Attention.
Would joint tiling work for the MLP in transformers?
A next natural question is: Can we extend this to the MLP in transformers? Maybe. At a high level, the MLP sees the opposite phenomenon: The weights are far larger than the intermediate outputs, which makes this joint-tiling strategy tougher to leverage. Let's see this using our constraints.
For the MLP, the matrix multiply inner dimension is $d = d_\text{model}$, but $d_\text{model}$ is usually fairly large, typically 4096 to 8192. We wouldn't be able to fit a $b \times b = d_\text{model} \times d_\text{model} = 4096 \times 4096$ tile in shared memory. For context, a 4096 by 4096 tile contains about 17 million values, so even if we quantized all those values to FP4, that would still be over 8 MB, about 90x larger than shared memory. Given this, we can't satisfy the constraint that $b = d_\text{model}$.
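To put a rough number on the overhead, here's a back-of-the-envelope estimate, assuming the same 96 KB of shared memory and FP4 values as above and ignoring the space needed for the other operands. The largest $b \times b$ tile that fits satisfies
$$\frac{b^2}{2} \text{ bytes} \leq 96 \text{ KB} \;\Rightarrow\; b \lesssim 440, \qquad \text{so} \qquad \left\lceil \frac{d_\text{model}}{b} \right\rceil \geq \left\lceil \frac{4096}{440} \right\rceil = 10.$$
By our FLOPs formula, that recomputation factor makes the tiled approach at least $\frac{10 + 1}{2} = 5.5$x more expensive in compute than the naive one.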
This doesn't mean that Flash Attention is impossible to apply, but it does mean that a Flash-Attention-esque approach would incur computational overhead. We would then need to compare the extra FLOPs with potentially-reduced memory accesses — not a super clear tradeoff or win.
Conclusion
In short, we've analyzed the computational and memory cost of jointly tiling back-to-back matrix multiplies, finding a nice and neat set of constraints that can tell us roughly when joint tiling is worthwhile. We need two constraints to be met:
- The block size should match the head dimension, $b = d_\text{head}$ for transformers and $b = k$ more generally, when $y = ABC$ with $A \in \mathbb{R}^{m \times k}$, $B \in \mathbb{R}^{k \times n}$, $C \in \mathbb{R}^{n \times k}$. This ensures that computational cost is kept constant.
- The dimensionality of our samples should be less than three times the number of samples, or $d_\text{head} < 3s$ for transformers and $k < 3n$ more generally, using the same matrices as in the last bullet. This ensures that memory accesses are reduced.
Taking a step back, the high-level takeaway is that we generally need small weight matrices and large intermediate outputs to reap these savings.
¹ Note that "joint tiling" in this particular case is more commonly called "fusing". The terms refer to slightly different concepts, which we'll cover in a later post in the series, When to fuse multiple matrix multiplies. For now, you can interpret "joint tiling" as "fusing", the latter being the more commonly used term.