from Guide to Machine Learning on May 7, 2023

When to tile two matrix multiplies

Matrix multiplication is an extremely well-studied and well-optimized operation. However, what if you have two matrix multiplies, instead of just one?

What's less obvious is how to optimize successive matrix multiplies. Flash Attention proposes a mechanism for accomplishing this in the n=2 case, jointly tiling a pair of matrix multiplies that are applied directly one after another; using this, Flash Attention is able to speed up language model training by over 3x! This raises the question: How do we jointly optimize two matrix multiplies? More generally, when is jointly optimizing multiple matrix multiplies beneficial, and when is it not?

How to optimize two matrix multiplies

Previously, we showed how to make a single matrix multiplication faster and more memory-efficient, using tiling. Effectively, we accomplished this by breaking down the matrix multiply into smaller, independent subproblems.

In this post, we'll extend this to two matrix multiplies, making the pair faster and more memory-efficient. In particular, we show how to jointly tile two matrix multiplies by again breaking them down into smaller, independent subproblems. Here's what this means in more detail:

  1. Say we have two matrix multiplications.
  2. Partially compute the first matrix multiplication, to obtain a part of the intermediate output.
  3. Immediately use this partial intermediate output. Notice we didn't have to fully compute and store the intermediate output.
  4. Partially compute the second matrix multiplication, to obtain a part of our final output.
  5. Repeat this process for all parts of the input until we complete the final output.

In this way, we never represent the intermediate output explicitly, saving unnecessary reads and writes to memory; in turn, this saves latency. This is what we call "joint tiling"[1] of two matrix multiplies. Let's now discuss how this works in more detail.
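To make this concrete, below is a minimal NumPy sketch of the idea for a three-matrix product $y = ABC$, the case we analyze in the rest of this post. The shapes, block size, and function name are illustrative assumptions only; a real GPU kernel would keep each block in fast on-chip memory rather than rely on NumPy.

```python
import numpy as np

def jointly_tiled_matmul(A, B, C, block=64):
    """Compute y = (A @ B) @ C one block of rows at a time.

    Only a (block, n) slab of the intermediate A @ B exists at any moment;
    the full intermediate matrix is never stored.
    """
    m = A.shape[0]
    y = np.empty((m, C.shape[1]), dtype=A.dtype)
    for start in range(0, m, block):
        stop = min(start + block, m)
        partial = A[start:stop] @ B   # step 2: part of the intermediate output
        y[start:stop] = partial @ C   # steps 3-4: immediately consume it
    return y

# Quick check against the naive two-step computation.
rng = np.random.default_rng(0)
A = rng.standard_normal((512, 64))   # m x k
B = rng.standard_normal((64, 256))   # k x n
C = rng.standard_normal((256, 64))   # n x k
assert np.allclose(jointly_tiled_matmul(A, B, C), (A @ B) @ C)
```

Note that this sketch tiles only over rows of the output, so each slab of the intermediate is computed exactly once; this roughly corresponds to the favorable $b = k$ case discussed later. Tiling over output columns as well would introduce the recomputation analyzed in the next sections.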

How to tile two matrix multiplies

Say we have two matrix multiplications back-to-back, such as $y = ABC$. We could start by simply applying tiled matrix multiplication twice. Unfortunately, this incurs quite a few fetches and writes:

  1. Fetch $A$ and $B$. Compute. Write the results to $\tilde{y}$.
  2. Fetch $\tilde{y}$ and $C$. Compute. Write the results to $y$.

Notice the extra reads and writes incurred for storing an intermediate $\tilde{y}$ result. Our goal is to eliminate those extra reads and writes, instead writing directly to $y$, without ever explicitly storing $\tilde{y}$.

Unfortunately, these savings aren't free; we need to recompute parts of our intermediate tensors multiple times and fetch our weight matrices repeatedly, in exchange.

To understand why, consider a single value of $y = \tilde{y}C$, our final output. That value requires a row of our intermediate output $\tilde{y}$ and a column of $C$; in turn, that row of $\tilde{y}$ requires a row of $A$ and all of $B$.

Taken together, this means we increase compute and memory accesses for our weight matrices, in exchange for fewer memory accesses for intermediate outputs.
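As a rough illustration with made-up sizes: if $A \in \mathbb{R}^{4096 \times 64}$, $B \in \mathbb{R}^{64 \times 4096}$, and $C \in \mathbb{R}^{4096 \times 64}$, the intermediate $\tilde{y} = AB$ holds $4096 \times 4096 \approx 16.8$M values, while $A$, $B$, and $C$ hold only $64 \times 4096 \approx 0.26$M values each. Trading a few extra fetches of $B$ and $C$ for skipping every read and write of $\tilde{y}$ is clearly a good deal here.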

This illustrates the intuitive takeaway: Jointly tile matrix multiplies if the intermediate outputs are larger than the weight matrices. Now, let's compute exactly when this tradeoff is favorable.

When to tile two matrix multiplies

Let's start with the computational cost of both approaches: the naive one, with two separate matrix multiplications, and the jointly-tiled one, where the two matrix multiplies are tiled together.

Say we wish to compute $\tilde{y} = AB \in \mathbb{R}^{m \times n}$ where $A \in \mathbb{R}^{m \times k}, B \in \mathbb{R}^{k \times n}$. In this case, for every one of the $mn$ outputs, we have $k$ multiplies and $k-1$ additions. This makes a total of $mn(2k - 1)$ FLOPs, which we'll approximate as $2mnk$ FLOPs.
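For instance, with hypothetical sizes $m = n = 1024$ and $k = 64$, this comes to $2 \cdot 1024 \cdot 1024 \cdot 64 \approx 1.3 \times 10^8$ FLOPs.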

Say we add another matrix multiplication so that $y = (AB)C \in \mathbb{R}^{m \times k}$ where $C \in \mathbb{R}^{n \times k}$. The naive approach would then incur

$$\text{FLOPs}_\text{naive} = 2mnk + 2mkn = 4mnk$$

FLOPs. In the jointly-tiled approach, we recompute each row of the intermediate output $\tilde{y}$ a total of $\lceil \frac{k}{b} \rceil$ times, for block size $b$. In other words, we recompute once for every block in a row of the final output $y$. This means in the fused approach, we have

$$\text{FLOPs}_\text{tiled} = 2mnk\lceil\frac{k}{b}\rceil + 2mkn = 2mnk(\lceil \frac{k}{b} \rceil + 1)$$

FLOPs. Notice that $b \leq k$, meaning the computational cost of tiling can only increase. The jointly-tiled and naive FLOP counts are identical when $b = k$. In other words, if a single block spans a full row of the final output $y$, each row of the intermediate $\tilde{y}$ is computed only once, and there is no computational overhead for fusing the two matrix multiplies. Ideally, this is a requirement we meet moving forward.
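As a quick sanity check on these formulas, here is a small helper that evaluates both FLOP counts. The shapes are arbitrary and chosen only to illustrate that the counts coincide at $b = k$ and diverge once $b < k$.

```python
import math

def flops_naive(m, n, k):
    # Two back-to-back matrix multiplies: 2mnk + 2mkn.
    return 2 * m * n * k + 2 * m * k * n

def flops_tiled(m, n, k, b):
    # Each row of the intermediate is recomputed ceil(k / b) times.
    return 2 * m * n * k * math.ceil(k / b) + 2 * m * k * n

m, n, k = 512, 2048, 64
print(flops_naive(m, n, k) == flops_tiled(m, n, k, b=k))   # True: no overhead when b = k
print(flops_tiled(m, n, k, b=16) / flops_naive(m, n, k))   # 2.5: recomputation overhead when b < k
```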

(Figure: an illustration of the $k=b$ case, in particular when $k = b = 1$. Notice that the weights $A$, $B$, and $C$ are far smaller than the intermediate output.)

Let's now look at the memory accesses for both approaches, reusing our example from above for consistency. According to our previous discussion on How to tile matrix multiplication, a tiled matrix multiplication with block size $b$ requires a total of $\frac{2mnk}{b}$ memory accesses. Running our two matrix multiplications back-to-back naively would thus require $\frac{2mnk}{b} + \frac{2mkn}{b}$ memory accesses. However, we also need to account for writing and then reading the intermediate output, which adds $2mn$ accesses ($mn$ writes plus $mn$ reads). This gives us a total of

$$\text{MemAcc}_\text{naive} = \frac{2mnk}{b} + \frac{2mkn}{b} + 2mn = 2mn(\frac{2k}{b} + 1)$$

memory accesses. In the tiled approach, we re-fetch $B$ once for every block of rows in the intermediate outputs $\tilde{y}$ — or equivalently, once for every block in a column of intermediate outputs $\tilde{y}$. We additionally don't need to read or write to intermediate outputs. This means in the tiled approach we have

$$\text{MemAcc}_\text{tiled} = \frac{2mnk}{b}\lceil\frac{n}{b}\rceil + \frac{2mkn}{b} = \frac{2mnk}{b}(\lceil\frac{n}{b}\rceil + 1)$$

memory accesses. The two expressions aren't as easy to compare directly, so let's work out what they tell us. The question we want to answer is: when does the tiled version perform fewer memory accesses than the naive one? For simplicity, ignore the ceil rounding.

$$\begin{align}\text{MemAcc}_\text{tiled} &< \text{MemAcc}_\text{naive}\\2mn(\frac{2k}{b} + 1) &< \frac{2mnk}{b}(\lceil\frac{n}{b}\rceil + 1)&\text{drop }2mn\\\frac{2k}{b} + 1 &< \frac{k}{b}(\lceil\frac{n}{b}\rceil + 1)&\text{multiply }b\\ 2k + b &< k\lceil\frac{n}{b}\rceil + 1&\text{multiply }b\\2kb + b^2 &< kn + b\\b^2 + 2b(k-1) - kn &< 0 & k \gg 1\\b^2 + 2bk - kn &< 0\end{align}$$

To be honest, I'm not sure what this expression says. However, let's consider the interesting case above, where $b = k$. In that case, we achieve the best-case scenario for tiled matrix multiply's computational cost. Knowing that, let's plug in $b = k$ to get

$$\begin{align}k^2 + 2(k)k - kn &< 0\\3k^2 - kn&< 0\\3k - n &< 0 \\k &< \frac{n}{3}\end{align}$$

In short, we need to meet the following two criteria:

  1. We need $b = k$ to not incur additional computational cost.
  2. This gives us the constraint that $k < \frac{n}{3}$ for there to be reduced memory accesses.

Let's now use this framework to understand the latency reduction for Flash Attention.

Why jointly tiling self-attention reduces latency

Let's consider a typical self-attention module in a Large Language Model. We covered the precise formulation in Practical Introduction to Large Language Models, where we found that a typical self-attention module has the following expression:

$$\begin{align}Y_a &= \text{Attention}(X_Q, X_K, X_V) \\ &= \text{softmax}(\frac{X_QX_K^T}{\sqrt{d}})X_V\end{align}$$

where $X_{\{Q,K,V\}} \in \mathbb{R}^{n \times d}$. Mapping this onto our earlier notation, $X_Q$ plays the role of $A$, $X_K^T$ plays the role of $B$, and $X_V$ plays the role of $C$, so the inner dimension $k$ becomes $d$ and the intermediate output is the $n \times n$ attention matrix. This means that our updated constraints are

  1. $b = d$ or block size must match $d$ to avoid extra computational cost.
  2. $d < \frac{n}{3}$ for there to be reduced memory accesses.

Here's the best part: the self-attention module satisfies both constraints.

  1. $d$ is usually very small. Per head, the dimensionality is usually only 64 to 128, so blocks of this width plausibly fit in 96 KB of shared memory. This means setting the block size $b = d$ is reasonable.
  2. $n$ is fairly large during training or batched inference, as this dimension includes both samples and token length. 16 samples of 64 tokens each already make $n = 1024$, easily satisfying the requirement that $d < \frac{n}{3}$.

These two convenient facts make jointly tiling attention matrix multiplications an obvious choice — thus, the effectiveness of Flash Attention.
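To see what this structure looks like in code, here is a heavily simplified NumPy sketch that tiles attention over blocks of queries, so the full $n \times n$ attention matrix is never materialized. It is meant only to illustrate the tiling structure: the block size and shapes are made up, and real Flash Attention additionally tiles over keys and values using an online softmax, which this sketch omits.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # subtract the row max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def query_tiled_attention(Q, K, V, block=128):
    """Attention tiled over query blocks; only a (block, n) slab of scores exists at a time."""
    n, d = Q.shape
    out = np.empty_like(Q)
    for start in range(0, n, block):
        stop = min(start + block, n)
        scores = Q[start:stop] @ K.T / np.sqrt(d)        # (block, n) slab of the n x n matrix
        out[start:stop] = softmax(scores, axis=-1) @ V   # immediately reduce it to (block, d)
    return out

# Matches the untiled reference computation.
rng = np.random.default_rng(0)
Q, K, V = [rng.standard_normal((1024, 64)) for _ in range(3)]
reference = softmax(Q @ K.T / np.sqrt(64), axis=-1) @ V
assert np.allclose(query_tiled_attention(Q, K, V), reference)
```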

Would joint tiling work for the MLP in transformers?

A natural next question is: Can we extend this to the MLP in transformers? Maybe. At a high level, the MLP exhibits the opposite phenomenon: the weights are far larger than the intermediate outputs, which makes this joint-tiling strategy tougher to leverage. Let's see this using our constraints:

  1. Our first constraint is that $b = d$. Unfortunately, for the MLP, the dimensionality $d$ of our input is fairly large, typically in the thousands (roughly 4096 to 8192). Although blocks this wide can still fit in shared memory, we would only be able to fit a few rows.
  2. Our second constraint is that $d < \frac{n}{3}$. Again, $d$ is already fairly large, meaning that $n$ would need to be even larger. Say we have a batch size of 16, as before. For $d = 4096$, this constraint requires a sequence length of more than $\frac{3d}{16} = 768$ tokens for joint tiling to be beneficial.

This isn't impossible; in fact, it's likely to happen during training. It's just not as obviously a winning scenario as attention is.
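For a rough feel for the numbers, here is a tiny check of the second constraint in both settings; the shapes are the illustrative ones used above, and the helper is hypothetical.

```python
def min_seq_len_for_savings(d, batch_size):
    # Constraint d < n / 3 with n = batch_size * seq_len  =>  seq_len > 3 * d / batch_size.
    return 3 * d / batch_size

print(min_seq_len_for_savings(d=64, batch_size=16))     # 12.0: trivial for an attention head
print(min_seq_len_for_savings(d=4096, batch_size=16))   # 768.0: plausible only for longer sequences
```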

Conclusion

In short, we've analyzed the computational and memory cost of jointly tiling back-to-back matrix multiplies, finding a neat set of constraints that tells us roughly when joint tiling is worthwhile. We need two constraints to be met:

  1. The block size must cover the inner dimension, $b = k$, so that joint tiling incurs no additional computational cost.
  2. The inner dimension must satisfy $k < \frac{n}{3}$, so that joint tiling reduces memory accesses.

Taking a step back, we can also say that, at a high level, we generally need small weight matrices and large intermediate outputs to reap these savings.


  1. Note that "joint tiling" in this particular case is more commonly called "fusing". The two terms refer to slightly different concepts, which we'll cover in a later post in the series, When to fuse multiple matrix multiplies. For now, you can read "joint tiling" as "fusing", the more commonly used term.