from Guide to Machine Learning on May 14, 2023

When to fuse multiple matrix multiplies

Optimizing a single matrix multiplication is well-studied. In fact, optimizing a pair of matrix multiplications has also been explored. What about optimizing a sequence of more than two matrix multiplies?

We showed previously how to optimize one and even two consecutive matrix multiplies by reducing memory accesses. Our next question is: How do we optimize more than two matrix multiplies? When is optimization across more than two matrix multiplies beneficial, and when is it not?

Principles for joint tiling

Let's start by repeating our previous approach of jointly tiling matrix multiplies. We'll again consider the same pair of matrix multiplies $y = ABC$ where $A \in \mathbb{R}^{m \times k}, B \in \mathbb{R}^{k \times n}, C \in \mathbb{R}^{n \times k}$.

As we saw in "When to tile two matrix multiplies", the general case requires recomputation. As a result, we restrict our attention to the case where $b = k$, or in other words, when an entire column of $B$ fits into a block.

As a refresher, this is what $b = k$ looks like; for simplicity, we illustrate $b = k = 1$. Here, $A$ and $C$ are both tall, narrow matrices, and $B$ is a wide matrix.

This time, we highlight two properties that enabled joint tiling in the previous post:

  1. Intermediate blocks were fully computed. Although we computed the intermediate output part by part, each part was fully computed by the time it was used. This is pictured above: to compute block (0, 0) of the intermediate output $AB$, we needed only block row 0 of $A$ and block column 0 of $B$, each of which is a single block because $k = b$.
  2. Intermediate blocks were used only once. Every block in the first row of the intermediate output contributes to exactly one final output block. This happens because $C$ is a narrow matrix, which makes the final output narrow as well: with $k = b$, the final output has a single block column.
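To make the two properties concrete, here is a minimal Python sketch of jointly tiling $y = ABC$ when $b = k$, using plain nested lists rather than a GPU kernel. The helper names (`matmul`, `add_inplace`, `jointly_tiled_abc`) are hypothetical, and the sketch assumes $m$ and $n$ are multiples of $b$. Each intermediate block of $AB$ is finished before it is used (property 1) and is consumed by exactly one output block before being discarded (property 2).

```python
# Minimal sketch (not the post's kernel): jointly tiled y = (A @ B) @ C for b == k.
# Matrices are lists of lists; m and n are assumed to be multiples of b.

def matmul(X, Y):
    """Naive dense matrix multiply for list-of-lists matrices."""
    cols = len(Y[0])
    return [[sum(x * Y[t][j] for t, x in enumerate(row)) for j in range(cols)] for row in X]


def add_inplace(acc, X):
    """acc += X, elementwise."""
    for i, row in enumerate(X):
        for j, v in enumerate(row):
            acc[i][j] += v


def jointly_tiled_abc(A, B, C, b):
    m, k, n = len(A), len(A[0]), len(B[0])
    assert k == b, "this sketch only covers the b == k case"
    y = []
    for i in range(0, m, b):
        A_i = A[i:i + b]                       # block row i of A: b x k, a single block
        y_i = [[0] * k for _ in range(b)]      # block i of the output y: b x k
        for j in range(0, n, b):
            B_j = [row[j:j + b] for row in B]  # block column j of B: k x b, a single block
            t_ij = matmul(A_i, B_j)            # property 1: block (i, j) of AB is complete here
            add_inplace(y_i, matmul(t_ij, C[j:j + b]))  # property 2: used once, then dropped
        y.extend(y_i)
    return y
```

Only one $b \times b$ intermediate block is alive at any moment, which is what lets a real kernel keep it in fast memory.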

Both properties allowed us to complete each block in the final output without any recomputation. Let's now try to extend these properties to more matrix multiplies.

Attempt #1: Change matrix sizes

Previously, our trick was to change matrix sizes so that joint tiling became worthwhile. Let's try that again now: If we have a third matrix multiply, how can we change matrix sizes to preserve the two properties above?

Say we now have three consecutive matrix multiplies $z = ABCD$ where $D \in \mathbb{R}^{k \times \ell}$. Just like before, we need to set $\ell = b$ to ensure that we use each block in the previous intermediate output $y = ABC$ only once; this also ensures that our computational cost does not increase.

Unfortunately, we also previously set $k = b$, which leaves $D \in \mathbb{R}^{b \times b}$ as a tiny, insignificant weight matrix. In fact, this constraint carries forward to every matrix multiply we add, so all additional matrix multiplies are degenerate.

Enforcing the constraint $k = b = \ell$ creates a set of degenerate additional matrix multiplies. Unfortunately, we need this constraint both to keep computational cost from growing and to maintain the second property: that every intermediate output is used only once.

More importantly, it's unlikely that this degenerate set of matrix multiplies would be applicable in any real-world setting. As a result, we turn our attention towards another approach.

Attempt #2: Change program order

Another possibility is to change program order to prevent recomputation. This means we compute an intermediate output once, then use it to compute all of the final outputs that depend on it.

We saw above that keeping the same matrix sizes for $A, B, C$ resulted in a degenerate $D$ matrix, so we'll simultaneously expand $D$ into a tall matrix and $C$ into a wider matrix. We'll now use a slightly modified version of $z = ABCD$, where $C \in \mathbb{R}^{n \times \ell}$ and $D \in \mathbb{R}^{\ell \times b}$. Here's what that looks like in more detail:

  1. Compute the first block row of $\tilde{y} = AB$. Our goal is now to reuse this computed row.
  2. Load the first block column of $C$, and multiply the first block row of $\tilde{y}$ by it to obtain the first block of $y = \tilde{y}C$. Then, multiply that block by the first block of $D$, accumulating the result in the first block of the final output $z$.
  3. Repeat this for all block columns of $C$ and the corresponding blocks of $D$ to finish the first block of the output $z$.
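Under similar assumptions (nested lists, hypothetical helper names, and $m$, $n$, $\ell$ all multiples of $b$), the three steps above might be sketched as follows. The variable `yt_row` holds the full block row of $\tilde{y}$, that is, $b \times n$ values, which is exactly the memory cost discussed next.

```python
# Minimal sketch of the reordered computation for z = A B C D,
# assuming C is n x ell and D is ell x b, with m, n, ell multiples of b.

def matmul(X, Y):
    """Naive dense matrix multiply for list-of-lists matrices."""
    cols = len(Y[0])
    return [[sum(x * Y[t][j] for t, x in enumerate(row)) for j in range(cols)] for row in X]


def add_inplace(acc, X):
    """acc += X, elementwise."""
    for i, row in enumerate(X):
        for j, v in enumerate(row):
            acc[i][j] += v


def row_ordered_abcd(A, B, C, D, b):
    m, ell = len(A), len(D)
    z = []
    for i in range(0, m, b):
        # Step 1: compute and keep the whole block row i of y~ = AB (b x n values).
        yt_row = matmul(A[i:i + b], B)
        z_i = [[0] * len(D[0]) for _ in range(b)]
        # Step 3: repeat over every block column of C and matching block of D.
        for j in range(0, ell, b):
            C_j = [row[j:j + b] for row in C]           # block column j of C: n x b
            y_ij = matmul(yt_row, C_j)                  # Step 2: block (i, j) of y = y~ C
            add_inplace(z_i, matmul(y_ij, D[j:j + b]))  # accumulate into block i of z
        z.extend(z_i)
    return z
```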

This design solves our recomputation issue, but it presents a different one: We now have to store an entire block row of $\tilde{y}$ ($b \times n$ values), unlike previous approaches that store only a single $b \times b$ block at any point in time. This severely limits the amount of work that a single kernel can perform, and it makes exploiting cache locality more difficult.

There are other ways to change program order as well, but many options lead to either increased memory consumption or recomputation.

Generalizing from tiling to fusion

In short, there isn't a clear way to jointly tile more than two matrix multiplies: the attempts above lead to degenerate additional matrices, increased memory consumption, or recomputation. This is by no means an exhaustive list of possibilities, but let's turn our attention to some other low-hanging fruit. Instead of tiling, let's focus on the more general tactic of fusing operations.

What it means to fuse matrix multiplies

Let's first recap our process for jointly tiling two matrix multiplies:

  1. Say we have two matrix multiplications.
  2. Partially compute the first matrix multiplication, to obtain a part of the intermediate output.
  3. Immediately use this partial intermediate output. Notice we didn't have to fully compute and store the intermediate output.
  4. Partially compute the second matrix multiplication, to obtain a part of our final output.
  5. Repeat this process for all parts of the input until we complete the final output.

In this way, we never represent the intermediate output explicitly, saving memory. This technique is more generally called fusing operations. Fusion generally breaks down an operation into independent subproblems, so that we can compute and use small parts of intermediate outputs, rather than computing and storing entire intermediate outputs.

What we did before — tiling — is a specific way of breaking down a matrix multiplication into subproblems; as a result, we can still fuse matrix multiplies without specifically tiling them.

Now the question is: How do we fuse without tiling? We need to find a different way to break down matrix multiplication into subproblems. Said another way, how do we "partially compute" the first matrix multiplication?

How to fuse multiple matrix multiplies

The idea is simple: If our input is very large, with many samples, we can compute on a subset of samples at a time. Our independent subproblems are then matrix multiplies on subsets of samples. There are two papers from UC Berkeley that took this approach, with slightly different variations:

In short, BPT's effectiveness boils down to:

  1. Are the weights larger than the inputs? If so, don't use BPT. Run all inputs through the model at once, so we load weights fewer times.
  2. Are the inputs larger? If so, chop up inputs so that each chunk fits in higher-bandwidth memory (e.g., SRAM instead of DRAM); then, we can read input activations from and write output activations to SRAM when computing each layer.

These two papers enable higher throughput and effectively faster inference when amortized across many samples, by "fusing" multiple matrix multiplies.
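As an illustration of fusing along the sample dimension, the sketch below pushes one chunk of samples through every layer before starting the next chunk. It is not either paper's implementation: the layers are bare matrix multiplies with no nonlinearities or attention, and the names are hypothetical. Because a plain matmul treats each row independently, chunking rows gives the same result as running everything at once; attention mixes tokens along the context dimension, so chunking along $c$ needs more care than shown here.

```python
# Minimal sketch: fuse a stack of matmul "layers" by running each chunk of
# samples (rows of X) through all layers before moving to the next chunk.
# Layers are bare matmuls with no nonlinearities, an assumption for brevity.

def matmul(X, Y):
    """Naive dense matrix multiply for list-of-lists matrices."""
    cols = len(Y[0])
    return [[sum(x * Y[t][j] for t, x in enumerate(row)) for j in range(cols)] for row in X]


def forward_chunked(X, weights, chunk):
    """Return (output, weight_loads, largest layer output in values)."""
    out, weight_loads, peak = [], 0, 0
    for s in range(0, len(X), chunk):
        H = X[s:s + chunk]                  # only this chunk's activations are live
        for W in weights:
            weight_loads += 1               # every chunk re-reads every weight matrix
            H = matmul(H, W)
            peak = max(peak, len(H) * len(H[0]))
        out.extend(H)
    return out, weight_loads, peak
```

The counters expose the trade-off in the list above: smaller chunks shrink the largest live activation but multiply how many times the weights are read.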

When to fuse multiple matrix multiplies

Let's now focus on the case where BPT is beneficial. How large a context length or batch size would we need to realize benefits? As we said above, BPT provides benefits most obviously when the activation size is larger than the weight size.

  1. Weight size: Using the formulas from "Practical Introduction to Large Language Models", we know that each decoder block uses $12d_\text{model}^2$ parameters, where $d_\text{model}$ is the dimensionality of each token.
  2. Activation size: We can grab expressions for intermediate activations from that post, which give us $7ncd_\text{model}$ total activations[1], where $n$ is the batch size and $c$ is the context length of each input sample.

Let's say you're currently generating output autoregressively, one token at a time. At this point, your batch size is effectively $n=1$. For activation size to exceed weight size, we need $7ncd_\text{model} > 12d_\text{model}^2$, so context length would need to satisfy

$$c > \frac{12}{7n}d_\text{model}$$

This threshold is about 7,000 for LLaMA 7B, where $d_\text{model} = 4096$. Once activation size exceeds weight size, it becomes potentially worthwhile to load weights multiple times so that your batched inputs and intermediate activations can fit into SRAM.
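The break-even arithmetic can be checked with a short sketch using the post's counts ($12d_\text{model}^2$ weights and $7ncd_\text{model}$ activations). The function names are hypothetical, and the bytes-per-value arguments are an added assumption reflecting footnote 2's caveat about datatypes: they cancel when weights and activations share a datatype.

```python
# Sketch: context length beyond which activations outgrow weights,
# using the post's counts. Function names are hypothetical.

def weight_count(d_model):
    return 12 * d_model ** 2            # parameters per decoder block


def activation_count(n, c, d_model):
    return 7 * n * c * d_model          # intermediate activations, per the post


def min_context_length(d_model, n, bytes_per_weight=2, bytes_per_activation=2):
    """Context length c above which activation bytes exceed weight bytes.

    Solves 7*n*c*d_model*bytes_per_activation > 12*d_model**2*bytes_per_weight
    for c. Byte sizes are an assumption; they cancel for matching dtypes.
    """
    return 12 * d_model * bytes_per_weight / (7 * n * bytes_per_activation)


print(round(min_context_length(4096, n=1)))  # LLaMA 7B, n = 1 -> 7022
print(round(min_context_length(4096, n=8)))  # n = 8 -> 878
```

With 4-bit weights (`bytes_per_weight=0.5`) and 16-bit activations, the same formula gives a threshold of roughly 1,755 tokens at $n = 1$.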

However, if you use a batch size greater than $n=1$ for, say, training, then this threshold shrinks by a factor of $n$. For a batch size of 8, for example, you would only need a context length of about 900 on LLaMA 7B[2].

Conclusion

In sum, jointly tiling more than two matrix multiplies is non-trivial, but fusing more than two matrix multiplies is simple: Break up the input into batches along the sample and context length dimensions. To determine if this will help latency, use the following inequality:

$$c > \frac{12}{7n}d_\text{model}$$

Then, apply these two rules:

  1. If this inequality is satisfied, your activation size exceeds your parameter size. In that case, take smaller subsets of samples ($n$) or context ($c$) until your activations fit in SRAM. The largest activation will be $ncd_\text{model}$ values. On an A100, with its 40 MB L2 cache, this means keeping $ncd_\text{model}$ times the number of bytes per activation value below $40 \times 10^6$; for 16-bit activations, that is $ncd_\text{model} < 20 \times 10^6$.
  2. If this inequality is not satisfied, your parameter size exceeds your activation size. Since your parameters are highly unlikely to ever fit in SRAM in LLM land, minimize the number of times you load them by passing all of your inputs in at once.
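The two rules can be folded into one rough planning function. This is a sketch, not a tuned policy: `plan` is a hypothetical name, the 40 MB budget is the A100 L2 size mentioned above, and 2 bytes per activation value (16-bit) is an assumption. Setting `bytes_per_value=1` reproduces a bound of $ncd_\text{model} < 40 \times 10^6$.

```python
# Rough sketch of the two rules. Assumptions: 16-bit activations
# (2 bytes per value) and the A100's 40 MB L2 cache as the SRAM budget.

def plan(n, c, d_model, sram_bytes=40_000_000, bytes_per_value=2):
    """Return ("all_at_once", tokens) or ("chunked", max_tokens_per_chunk)."""
    weights = 12 * d_model ** 2
    activations = 7 * n * c * d_model
    if activations <= weights:
        # Rule 2: weights dominate, so load them once and pass all inputs through.
        return "all_at_once", n * c
    # Rule 1: activations dominate; size chunks so the largest activation,
    # (tokens per chunk) * d_model values, fits in the SRAM budget.
    max_tokens = sram_bytes // (d_model * bytes_per_value)
    return "chunked", max_tokens
```

For LLaMA 7B ($d_\text{model} = 4096$), this caps each chunk at 4,882 tokens ($n \times c$) with 16-bit activations.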

With these general heuristics, you should now be able to fuse multiple matrix multiplies in Large Language Models.


  1. In the original post, we wrote $7nd_\text{model}$ total activations, where $n$ was actually the batch size multiplied by the context length. While copying that expression into this post, I rewrote $n$ as $nc$ to make this point clearer. 

  2. All of this assumes that your weights and activations have the same datatype. This calculus changes if your weights are quantized, for example. In that case, multiply each side of the inequality by the number of bytes per value: bytes per weight on the weight side, and bytes per activation on the activation side.