from Guide to Machine Learning on Apr 30, 2023

How to tile matrix multiplication

Matrix multiplication is a staple of deep learning and a well-studied, well-optimized operation. One of the most common optimizations for matrix multiplication is called "tiling," but as common and important as it is, it's a bit confusing to understand.

Tiling matrix multiplication is a valuable technique that optimizes resource utilization in multiple dimensions, including power, memory, and compute. Critically, tiling then reduces overall latency, making it vital for models that rely heavily on dense matrix multiplication.

Transformers and the Large Language Models built on them are one such example; their heavy reliance on dense matrix multiplies for inference makes tiling an important concept to understand, and to leverage.

In this post, we'll break down how tiling for matrix multiplication works, again by conveying intuition primarily through illustrations.

How does tiling work?

I'll start with a description of how to tile a single matrix multiply. Here we only cover the most salient parts at a high level.

Let's first multiply two 8x8 matrices $A$ and $B$ normally. To do so, we take the inner product of every row of $A$ with every column of $B$. We illustrate this below.

Here's what that process looks like in more detail:

  1. Fetch the first row $A_{0,:}$ (8 fetches).
  2. Fetch the first column $B_{:,0}$ (8 fetches).
  3. Take the inner product to get one value in our output $O_{0,0}$. Repeat this for all 64 output values.

For each of the 64 values in our output, we need to fetch a total of 16 values: 8 values from $A$ and 8 values from $B$. This means we need $64 \times 16 = 1024$ total fetches.
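A minimal pure-Python sketch of this counting, assuming one "fetch" per element read from $A$ or $B$. The helper `naive_matmul` and the example matrix values are illustrative, not from the original post:

```python
def naive_matmul(A, B):
    """Naive matrix multiply that also counts element fetches.

    For every output O[r][c], fetch the full row A[r, :] and the full
    column B[:, c] (k values each), then take their inner product.
    """
    m, k, n = len(A), len(B), len(B[0])
    O = [[0] * n for _ in range(m)]
    fetches = 0
    for r in range(m):
        for c in range(n):
            row = [A[r][i] for i in range(k)]  # k fetches from A
            col = [B[i][c] for i in range(k)]  # k fetches from B
            fetches += len(row) + len(col)
            O[r][c] = sum(x * y for x, y in zip(row, col))
    return O, fetches


# The 8x8 example from the text (matrix values are arbitrary).
N = 8
A = [[r * N + c for c in range(N)] for r in range(N)]
B = [[(r + 2 * c) % 5 for c in range(N)] for r in range(N)]
O, fetches = naive_matmul(A, B)
print(fetches)  # 1024 = 64 outputs x 16 fetches
```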

As a first optimization, we can reuse rows of $A$. This is pictured below, where we fetch one row of $A$ to compute the entire first row of output.

Here's that process in more detail:

  1. Fetch row $A_{0,:}$ (8 fetches) and column $B_{:,0}$ (8 fetches).
  2. Compute inner product for $A_{0,:}$ and $B_{:,0}$ to get $O_{0,0}$ like before.
  3. Now, reuse $A_{0,:}$ from before (0 fetches). Fetch the next column $B_{:,1}$ (8 fetches).
  4. Compute inner product for $A_{0,:}$ and $B_{:,1}$ to get $O_{0,1}$.
  5. Repeat this for all columns $B_{:,c}$ to get the first row of outputs $O_{0,:}$. Repeat this for all rows $A_{r,:}$.

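For the example above, we fetch each row of $A$ only once, for $8 \times 8 = 64$ fetches, but we still fetch a column of $B$ for every one of the 64 outputs, for $64 \times 8 = 512$ fetches. That's a total of $64 + 512 = 576$ fetches.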
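The row-reuse ordering can be sketched the same way. Again, `row_reuse_matmul` is a hypothetical helper and a "fetch" is simply one element read; each row of $A$ is read once per output row, while each column of $B$ is still read once per output:

```python
def row_reuse_matmul(A, B):
    """Fetch each row of A once and reuse it for a whole row of outputs.

    Columns of B are still fetched again for every single output.
    """
    m, k, n = len(A), len(B), len(B[0])
    O = [[0] * n for _ in range(m)]
    fetches = 0
    for r in range(m):
        row = [A[r][i] for i in range(k)]  # fetched once per output row
        fetches += k
        for c in range(n):
            col = [B[i][c] for i in range(k)]  # fetched once per output
            fetches += k
            O[r][c] = sum(x * y for x, y in zip(row, col))
    return O, fetches


N = 8
A = [[r * N + c for c in range(N)] for r in range(N)]
B = [[(r + 2 * c) % 5 for c in range(N)] for r in range(N)]
O, fetches = row_reuse_matmul(A, B)
print(fetches)  # 576 = 8 rows x 8 + 64 outputs x 8
```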

However, notice that each column from $B$ is only used to generate 1 output. This makes our fetches relatively inefficient.

For our next step, we change the order in which we generate output values. Instead of generating the first row, we generate the values in the top-left quadrant.

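Now, each row in $A$ is used to generate 4 outputs, and every column in $B$ is used to generate 4 outputs. This once again lowers the number of fetches we need: each of the four 4x4 output quadrants needs 4 rows of $A$ and 4 columns of $B$, or $4 \times 8 + 4 \times 8 = 64$ fetches, for a total of $4 \times 64 = 256$ fetches.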
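Generalizing this to a block size of $b \times b$ gives a small tiled sketch. `tiled_matmul` is a hypothetical name, and the sketch assumes $b$ divides the matrix dimensions. With $b = 1$ it reduces to the naive order, and with $b = 4$ it visits the four quadrants:

```python
def tiled_matmul(A, B, b):
    """Generate outputs in b x b blocks instead of row by row.

    For each output block, fetch its b rows of A and b columns of B once,
    then reuse them for all b * b outputs in that block.
    """
    m, k, n = len(A), len(B), len(B[0])
    assert m % b == 0 and n % b == 0, "this sketch assumes b divides m and n"
    O = [[0] * n for _ in range(m)]
    fetches = 0
    for r0 in range(0, m, b):
        for c0 in range(0, n, b):
            rows = [[A[r][i] for i in range(k)] for r in range(r0, r0 + b)]
            cols = [[B[i][c] for i in range(k)] for c in range(c0, c0 + b)]
            fetches += 2 * b * k  # b rows + b columns, k values each
            for dr, row in enumerate(rows):
                for dc, col in enumerate(cols):
                    O[r0 + dr][c0 + dc] = sum(x * y for x, y in zip(row, col))
    return O, fetches


N = 8
A = [[r * N + c for c in range(N)] for r in range(N)]
B = [[(r + 2 * c) % 5 for c in range(N)] for r in range(N)]
for b in (1, 2, 4):
    print(b, tiled_matmul(A, B, b)[1])  # 1 -> 1024, 2 -> 512, 4 -> 256
```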

This completes our matrix multiplication! To summarize:

  1. We started with 1024 fetches for the naive matrix multiply.
  2. Reusing rows from $A$ reduced that cost to 576 fetches.
  3. Generating outputs by block instead of by row reduced the cost to 256 fetches.

That's a reduction of 4x, from 1024 to 256 fetches — a much more efficient matrix multiply.

How effective is tiling?

Above, we saw that tiled matrix multiplication with a 4x4 block size reduced memory accesses by 4x, from 1024 to 256 fetches. Similarly, a 2x2 block size reduces memory accesses by 2x, from 1024 to 512 fetches: each of the 16 output blocks needs 2 rows of $A$ and 2 columns of $B$, or $2 \times 8 + 2 \times 8 = 32$ fetches, for $16 \times 32 = 512$ fetches in total.

Notice the pattern? With a block size of $b \times b$, we reduce the number of fetches by a factor of $b$. Here's the intuition: In the original matrix multiplication, every row and column vector is used to generate only one output value each time it is fetched. In the tiled matrix multiplication, every row and column vector is used to generate $b$ output values each time it is fetched. This is why, intuitively, tiling reduces the total number of fetches by a factor of $b$.

To make this more concrete, let's count the number of fetches with general matrix dimensions. Say $A \in \mathbb{R}^{m \times k}, B \in \mathbb{R}^{k \times n}$. Then,

  1. Our original matrix multiplication costs $k + k$ fetches for every one of the $mn$ outputs. This makes $2mnk$ fetches.
  2. Say our block size is $b$. Our tiled matrix multiplication uses $b$ rows of $k$ values each and $b$ columns of $k$ values each. That makes $2bk$ fetches for each $b \times b$ block, where there are $(\frac{m}{b})(\frac{n}{b}) = \frac{mn}{b^2}$ total blocks. This makes $2bk\frac{mn}{b^2} = \frac{2mnk}{b}$ fetches.
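As a sanity check on this derivation, a short sketch can walk the output blocks for a few arbitrary shapes and compare the count against $\frac{2mnk}{b}$. The helper `count_tiled_fetches` is hypothetical and assumes $b$ divides both $m$ and $n$:

```python
def count_tiled_fetches(m, n, k, b):
    """Count fetches by walking the b x b output blocks without doing any math.

    Each block fetches b rows of A and b columns of B, k values each.
    Assumes b divides both m and n.
    """
    fetches = 0
    for _ in range(0, m, b):          # block rows
        for _ in range(0, n, b):      # block columns
            fetches += 2 * b * k
    return fetches


for m, n, k, b in [(8, 8, 8, 1), (8, 8, 8, 4), (64, 32, 128, 8), (1024, 512, 4096, 16)]:
    counted = count_tiled_fetches(m, n, k, b)
    assert counted == 2 * m * n * k // b  # matches the closed form
    print(m, n, k, b, counted)
```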

Notice this cost for tiled matrix multiplication is exactly the original matrix multiply's cost $2mnk$ divided by $b$, confirming what we noticed empirically above. A block size of $b$ reduces the total number of memory accesses by a factor of $b$. The bigger the block, the better.

Why "tiling" is so fast

Tiling leverages several principles to make matrix multiplication run really fast:

  1. Parallelization: We can calculate the results of each output block independently, so if there are 4 blocks like in our example, we can run 4 threads concurrently.
  2. Better memory management: We saw part of this above, since tiling matrix multiplication reduces the number of memory accesses. As we'll see below, however, there is more than one way tiling improves speed.
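To illustrate that output blocks are independent, here is a sketch that hands each $b \times b$ block to a Python thread pool. It demonstrates the dependency structure, not speed: CPython threads share one interpreter lock, so this will not run faster. On a GPU, tiled kernels typically assign each output block to its own thread block instead. `compute_block` and `parallel_tiled_matmul` are hypothetical names:

```python
from concurrent.futures import ThreadPoolExecutor


def compute_block(A, B, r0, c0, b):
    """Compute the b x b output block whose top-left corner is (r0, c0).

    It only reads A and B and returns its own block, so blocks do not
    depend on each other and can run concurrently.
    """
    k = len(B)
    return [[sum(A[r][i] * B[i][c] for i in range(k))
             for c in range(c0, c0 + b)]
            for r in range(r0, r0 + b)]


def parallel_tiled_matmul(A, B, b, workers=4):
    m, n = len(A), len(B[0])
    O = [[0] * n for _ in range(m)]
    corners = [(r0, c0) for r0 in range(0, m, b) for c0 in range(0, n, b)]
    with ThreadPoolExecutor(max_workers=workers) as pool:
        blocks = pool.map(lambda rc: compute_block(A, B, rc[0], rc[1], b), corners)
        # Stitch each independently computed block back into O.
        for (r0, c0), block in zip(corners, blocks):
            for dr, row in enumerate(block):
                O[r0 + dr][c0:c0 + b] = row
    return O


N = 8
A = [[r * N + c for c in range(N)] for r in range(N)]
B = [[(r + 2 * c) % 5 for c in range(N)] for r in range(N)]
O = parallel_tiled_matmul(A, B, b=4)  # 4 output blocks -> 4 concurrent tasks
```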

Before diving into auxiliary benefits of tiling, let's understand why memory fetches matter so much for latency. In the V100 setting described in footnotes 2 and 3, it takes 3.8ms to move the weights and only 1.2ms to actually compute with them. Memory bandwidth is therefore the bottleneck.[4]

Now, here's how tiling a matrix multiplication helps. With a block size of 4, we reduce the number of memory accesses by 4x, cutting the time spent moving weights from 3.8ms to 0.95ms. On top of that, we can now also parallelize computation across blocks, so the time spent computing decreases as well.

Increasing the block size further would keep reducing latency, so you might wonder: What stops us from increasing the block size indefinitely? This is where memory constraints come in.

Why there's a limit to tiling

In short, your hardware may not be able to hold all the values a block needs in its fast memory.

For example, let's say you're operating on the same 4096x4096 matrices from before, now with an output block size of 8x8. To compute a single output block, we need 8 rows of 4096 values and 8 columns of 4096 values. Altogether, in FP16, this is $(4096 \times 8 + 4096 \times 8) \times 2 = 131072$ bytes, or 131KB of data. Unfortunately, the V100 has at most 96KB of shared memory.[5] We now have two options:

  1. We can reduce the output block size from 8x8 to 4x4. This requires just $(4096 \times 4 + 4096 \times 4) \times 2 = 65536$ bytes, or 66KB, which fits in shared memory. However, this doubles the number of memory accesses.
  2. We can accept the slowdown: our fetched values don't fit in 96KB of shared memory and will instead sit in the 6MB L2 cache. However, L2 cache load speed is about 2.2 TB/s, roughly 6x slower than shared memory's load speed of 13 TB/s.[6]
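The footprint arithmetic for the two output block sizes can be checked directly. This sketch assumes FP16 values (2 bytes each) and treats the V100's 96KB of shared memory as $96 \times 1024$ bytes; like the text, it counts only the input rows and columns, not the outputs. The helper name is hypothetical:

```python
FP16_BYTES = 2                 # bytes per value, assuming FP16
SHARED_MEM_BYTES = 96 * 1024   # V100 shared memory maximum from the text


def input_bytes_per_block(b, k, dtype_bytes=FP16_BYTES):
    """Bytes for the b rows of A and b columns of B (k values each) needed
    to compute one b x b output block. Outputs are not counted."""
    return (k * b + k * b) * dtype_bytes


k = 4096
for b in (8, 4):
    need = input_bytes_per_block(b, k)
    print(b, need, need <= SHARED_MEM_BYTES)
# 8 131072 False
# 4 65536 True
```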

Neither option is very desirable, so in effect, shared memory size limits our output block size. Fortunately, we have one more trick.

How to sidestep tiling limits

Let's consider the limits that shared memory imposes. Let's go back to our original example with 4x4 output blocks, for an 8x8 matrix. Here's a visual representation of our tiled matrix multiplication, which we introduced before.

Notice that we need to hold values from all three matrices in shared memory simultaneously: 32 values from $A$, 32 values from $B$, and 16 values of the output $O$. If all our values are in FP16, this is $(32 + 32 + 16) \times 2 = 160$B. This isn't bad at all, but let's say that our shared memory can hold only a little more than half as much: just 96B.

To accommodate this stricter memory constraint, we can fetch only part of a row and only part of a column.

  1. Fetch only the first half of each row and column. This gives us 4x4 subsets of $A$ and $B$. Take the matrix product to obtain a 4x4 output $O$. However, we haven't yet fully computed the outputs.
  2. Fetch the second half of each row and column, which are again 4x4 subsets of $A$ and $B$. Take the matrix product to obtain a 4x4 output $O$. Add these 4x4 outputs to the 4x4 outputs from the previous step.

Here is the process visualized.

Here's the process in more detail.

  1. Fetch the top-left quadrant for both matrices, $A_{:4,:4}$ and $B_{:4,:4}$. There are only 32 values used for computation, which is 64B.
  2. Perform a naive matrix multiply on this pair of 4x4 matrices to obtain a 4x4 output, which we store in $O_{:4,:4}$. This is 16 values or 32B of outputs. Notice at this point that although each output value is filled, it isn't "complete". For example, if we look at the top-left output, its value is $A_{0,0}B_{0,0} + A_{0,1}B_{1,0} + A_{0,2}B_{2,0} + A_{0,3}B_{3,0}$. In other words, it's missing "half" of its value, which is $A_{0,4}B_{4,0} + A_{0,5}B_{5,0} + A_{0,6}B_{6,0} + A_{0,7}B_{7,0}$. We compute this other "half" next.
  3. Fetch the top-right quadrant for our first matrix $A_{:4,4:}$ and the bottom-left quadrant for our second matrix $B_{4:,:4}$. There are again only 32 values used for computation, which is 64B.
  4. Perform a naive matrix multiply on this pair of 4x4 matrices to obtain a 4x4 output, which we accumulate in the same block of output, $O_{:4,:4}$. Now, that block of output is complete.
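Here is a sketch of this input-blocked scheme on the same 8x8 setup: the shared dimension $k$ is read in chunks of $\ell$ (`ell` in code), and each chunk's partial product is accumulated into the output block. The function names are hypothetical, and a "write" here means one accumulation into an output value:

```python
def matmul(A, B):
    """Reference matrix product for list-of-lists matrices."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*B)] for row in A]


def blocked_inner_matmul(A, B, b, ell):
    """Compute O in b x b output blocks, reading the shared dimension k in
    chunks of ell and accumulating each partial product into the block.

    Only a b x ell slice of A, an ell x b slice of B, and the b x b output
    block are needed at once. Returns O and the number of output writes
    (one per output value per chunk). Assumes b divides m and n, ell divides k.
    """
    m, k, n = len(A), len(B), len(B[0])
    O = [[0] * n for _ in range(m)]
    writes = 0
    for r0 in range(0, m, b):
        for c0 in range(0, n, b):
            for k0 in range(0, k, ell):
                a_tile = [A[r][k0:k0 + ell] for r in range(r0, r0 + b)]  # b x ell
                b_tile = [B[i][c0:c0 + b] for i in range(k0, k0 + ell)]  # ell x b
                partial = matmul(a_tile, b_tile)                         # b x b
                for dr in range(b):
                    for dc in range(b):
                        O[r0 + dr][c0 + dc] += partial[dr][dc]  # accumulate
                        writes += 1
    return O, writes


N = 8
A = [[r * N + c for c in range(N)] for r in range(N)]
B = [[(r + 2 * c) % 5 for c in range(N)] for r in range(N)]
O, writes = blocked_inner_matmul(A, B, b=4, ell=4)
assert O == matmul(A, B)
print(writes // 4)  # 32 writes per 4x4 output block (4 blocks in total)
```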

Notice that at any point, we used at most 96B: 32 input values taking 64B and 16 output values taking 32B. This satisfies our constraint for fitting into shared memory. We can repeat this process for all blocks in the output to complete our matrix multiplication.

In short, reading the inputs block by block let us compute with fewer input values held at once, lowering memory requirements while preserving the same number of fetches. To summarize:

  1. We started with 160B of shared memory, storing 64 input values and 16 output values, with 16 writes.
  2. By blocking the input, we require just 96B of shared memory, storing only 32 input values and 16 output values at any given time, at the cost of 32 writes.

More generally, again assume $A \in \mathbb{R}^{m \times k}, B \in \mathbb{R}^{k \times n}$. Then,

  1. We started with $b$ rows of $k$ values and $b$ columns of $k$ values, making $2bk$ input values. We also stored $b^2$ outputs. This is $b(2k + b)$ values or $2b(2k + b)$ total bytes, assuming half-precision.
  2. If we block the input into chunks of $b \times \ell$, then we have $2b\ell$ input values and still $b^2$ outputs. This is $b(2\ell + b)$ values or $2b(2\ell + b)$ bytes.

Notice the only difference is we exchanged $k$ for $\ell$ in our memory consumption.

This raises another natural question: Why not use smaller and smaller input block sizes indefinitely? The only limitation is the number of writes. Namely, with an input block size of $b \times \ell$, each output value incurs $\frac{k}{\ell}$ writes.
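The general memory and write counts fit in two small helpers, which reproduce the 160B vs. 96B figures and the 1 vs. 2 writes per output value from the 8x8 example. The names are hypothetical, FP16 (2 bytes per value) is assumed as in the text, and the last line simply plugs in an arbitrary $\ell = 512$ for the 4096x4096 case:

```python
FP16_BYTES = 2  # bytes per value, assuming half precision as in the text


def resident_bytes(b, chunk, dtype_bytes=FP16_BYTES):
    """Memory held at once for one b x b output block when the shared
    dimension is read `chunk` values at a time: 2*b*chunk input values
    plus b*b output values. chunk = k means no input blocking."""
    return (2 * b * chunk + b * b) * dtype_bytes


def writes_per_output(k, ell):
    """Each output value is accumulated once per chunk of length ell."""
    assert k % ell == 0, "this sketch assumes ell divides k"
    return k // ell


# 8x8 example with 4x4 output blocks (k = 8).
print(resident_bytes(4, 8), writes_per_output(8, 8))  # 160 1
print(resident_bytes(4, 4), writes_per_output(8, 4))  # 96 2
# 4096x4096 example with 8x8 output blocks and ell = 512.
print(resident_bytes(8, 512), writes_per_output(4096, 512))  # 16512 8
```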

Conclusion

There are several moving pieces here for actually configuring and using tiling in practice; we mentioned these factors above:

  1. Memory subsystems: There are several memory subsystems, and each level is larger but slower than the last. On a V100, we have 96KB of shared memory with a load speed of 13TB/s, 6MB of L2 cache at 2.2TB/s, 32GB of global memory[8] at 900GB/s, and 1TB of CPU RAM[7] at 94GB/s. Each level is roughly 2-10x slower but also 30-5000x larger than the one before it.
  2. Block sizes: Our input matrices $A$ and $B$ start out in global memory, which is relatively slow to read from. Since the number of memory accesses is inversely proportional to the output block size, we want large output blocks. At the same time, to satisfy memory constraints (ideally, to fit each block into shared memory), we want small input blocks. Note that smaller input blocks incur more writes.
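The ratios between adjacent memory levels can be computed directly from the V100 figures listed above. This sketch uses decimal units (1KB = $10^3$ bytes) and the bandwidth and size numbers exactly as given; `LEVELS` and `adjacent_ratios` are illustrative names:

```python
# (name, size in bytes, load bandwidth in bytes/s), using the figures listed
# above for a V100 system.
LEVELS = [
    ("shared memory", 96e3, 13e12),
    ("L2 cache", 6e6, 2.2e12),
    ("global memory", 32e9, 900e9),
    ("CPU RAM", 1e12, 94e9),
]


def adjacent_ratios(levels):
    """For each pair of adjacent levels, return (label, slowdown, growth):
    how much lower the next level's bandwidth is and how much larger it is."""
    ratios = []
    for (name, size, bw), (next_name, next_size, next_bw) in zip(levels, levels[1:]):
        ratios.append((f"{name} -> {next_name}", bw / next_bw, next_size / size))
    return ratios


for label, slowdown, growth in adjacent_ratios(LEVELS):
    print(f"{label}: {slowdown:.1f}x slower, {growth:.0f}x larger")
```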

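There are also several values we computed for tiled matrix multiplies. Let $A \in \mathbb{R}^{m \times k}, B \in \mathbb{R}^{k \times n}$, with output blocks of size $b \times b$ and input blocks of size $b \times \ell$:

  1. Memory accesses: $2mnk$ fetches without tiling, and $\frac{2mnk}{b}$ fetches with output blocks of size $b \times b$.
  2. Memory per output block, in half precision: $2b(2k + b)$ bytes without input blocking, and $2b(2\ell + b)$ bytes with input blocks of size $b \times \ell$.
  3. Writes: $\frac{k}{\ell}$ writes per output value with input blocking, versus 1 without.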

This leads us to our final takeaway: Tiling matrix multiplications reduces memory accesses and memory usage by blocking outputs and inputs, respectively; this ultimately reduces overall latency for a memory-bandwidth-limited operation.

This concludes our discussion on tiling matrix multiplies. To generalize to more matrix multiplies, see When to tile two matrix multiplies.

To see an implementation of tiling in action, see Triton's matrix multiplication tutorial, where you'll build a custom CUDA kernel from the comfort of Python. There are also many other well-written resources on this topic, if you'd like to explore alternative explanations.





  1. Among other contributions of course. Flash Attention proposes tiling the softmax activation as well by simply storing the denominator accumulated so far. 

  2. The statistics for memory and compute bandwidth are taken directly from Nvidia's official V100 specs. For now, I'm assuming the matrices are stored in FP16 and have dimensions divisible by 8, which meets the requirements for Nvidia's tensor cores. This is generally true of massive matrix multiplies in transformers. This is why we look at "Tensor Performance" on that page. 

  3. This is generally true of large language models of around 7 billion parameters, from LLaMA to Pythia to Cerebras-GPT: 3 attention matrices of 4096x4096 each for 31 layers.

  4. There's a bit more nuance to memory being the limiting factor. Loading weights actually bottlenecks compute, as the compute depends on certain rows and columns being loaded. With that said, you can pipeline computation to a limited degree. 

  5. Listed as "shared memory" on page 10 of the V100 architecture whitepaper

  6. Zhe Jia et al. note in "Dissecting the NVIDIA Volta GPU Architecture via Microbenchmarking" that the V100 L2 cache has a load speed of 2155 GB/s (page 23) or 2.2 TB/s. Shared memory has a load speed of 12,080 GiB/s (page 19) or $12080 \times 1.07374 = 13.0 \times 10^3$ GB/s, which is 13 TB/s. 

  7. Using memory bandwidth for an arbitrary CPU — in this case, an Intel Core X-series processor has a theoretical bandwidth of 94 GB/s

  8. "Global memory" is what CUDA uses to refer to VRAM, or video random-access memory, on a GPU. We can also generically call this "GPU RAM".