from Guide to Machine Learning on Jul 7, 2024

How speculative decoding works

As we discussed in Why Large Language Model inference is memory bound, we have three general tactics for improving inference latency:

  1. Directly address the memory bottleneck by lowering the amount of data we need to transfer. This includes techniques like quantization and pruning.
  2. Directly address the memory bottleneck by reordering the workload to better reuse data. This includes techniques like tiling and Flash Attention.
  3. Leverage underutilized compute by increasing the total amount of computation. In this post, we'll explore a technique in this last category, called "speculative decoding".

Let's hop right in.

Improving latency by batching

Let's see how batching affects the arithmetic intensity of matrix multiplication. Does batching change arithmetic intensity enough to turn a memory bound workload into a compute bound one?

In our previous discussions of tiled matrix multiplication, we deduced that arithmetic intensity is roughly $b$, the tile dimension. In that setup, we multiplied an $m \times n$ matrix by an $n \times k$ matrix, although $m$ and $k$ did not affect the result. Let's change this setup: say the first matrix is batched with batch size $s$, making the first matrix $s \times m \times n$. Now, our calculus changes slightly.

Before, for every output tile of size $b \times b$, we would load $bn$ values from the first matrix and $nb$ values from the second, write $b^2$ outputs, and execute $2nb^2$ FLOPs. Now, we load $snb$ values from the first matrix for the same $nb$ values from the second, write $sb^2$ outputs, and execute $2nsb^2$ FLOPs. Our new arithmetic intensity is therefore

$$\frac{2nsb^2}{snb + nb + sb^2} = \frac{2nsb}{sn + n + sb} = \frac{2snb}{n(1+s) + sb}$$

As a sanity check, note that when $s = 1$, the expression reduces to the arithmetic intensity from Why Large Language Model inference is memory bound. There isn't a particularly good way to simplify this expression, so instead, let's plot it. To do so, we'll assume fixed values for $n = 2048$ and $b = 128$, then vary the batch size $s$.
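Here is a minimal Python sketch of the calculation behind the plot. It assumes every value costs the same to move and counts, per $b \times b$ output tile, the $snb + nb$ values loaded, the $sb^2$ values written, and the $2nsb^2$ FLOPs described above. The function name is hypothetical.

```python
def batched_arithmetic_intensity(s: int, n: int, b: int) -> float:
    """FLOPs per value moved for one b x b output tile at batch size s.

    Counts follow the tiling argument: s*n*b values loaded from the batched
    first matrix, n*b values loaded once from the shared second matrix, and
    s*b*b output values written.
    """
    flops = 2 * n * s * b * b
    values_moved = s * n * b + n * b + s * b * b
    return flops / values_moved


if __name__ == "__main__":
    n, b = 2048, 128  # inner dimension and tile size used for the plot
    for s in range(1, 10):
        print(f"s={s}: {batched_arithmetic_intensity(s, n, b):.1f}")
```

For $s = 1$ it returns about 124.1, matching the unbatched intensity $2nb/(2n + b)$ for $n = 2048$ and $b = 128$. Rewriting the ratio as $2nb/(n + b + n/s)$ shows why growth flattens out: as $s$ grows, the $n/s$ term vanishes and the intensity approaches $2nb/(n + b)$, a ceiling set by $n$ and $b$.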

[Figure: Batched Matrix Multiplication, d_model=2048, tile_size=128. Arithmetic intensity (y-axis) vs. batch size $s$ (x-axis). Arithmetic intensity grows quickly as batch size increases. This rate of growth is attenuated by the inner dimension $n$ and tile size $b$.]

As you can see, increasing the batch size strictly improves arithmetic intensity. Critically, increasing the batch size just to 5 or more already transforms inference from a memory bound to a compute bound workload.

In summary, then, batching can substantially increase arithmetic intensity; in this example, by roughly 1.7x. Meanwhile, the extra memory for batching is negligible: during decoding $m = 1$, so the batched first matrix holds just $sn$ values, which is tiny next to the $nk$ values of the second matrix as long as $s \ll k$. All in all, arithmetic intensity goes up significantly with negligible memory overhead. This is our core insight: batching allows us to leverage unused compute "for free".
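The memory side of the argument is simple arithmetic as well. This minimal sketch assumes decoding with $m = 1$, so each batch element contributes one $n$-dimensional input row and one $k$-dimensional output row, while the $n \times k$ second matrix is shared. The sizes $n = k = 2048$ are an illustrative choice, and the helper names are hypothetical.

```python
def values_moved(s: int, n: int, k: int) -> int:
    """Values read or written by one decoding matmul with m = 1 and batch size s.

    s*n: batched first matrix (one n-dimensional row per batch element)
    n*k: shared second matrix, read once regardless of s
    s*k: batched output
    """
    return s * n + n * k + s * k


def batching_overhead(s: int, n: int, k: int) -> float:
    """Fractional increase in memory traffic relative to batch size 1."""
    return values_moved(s, n, k) / values_moved(1, n, k) - 1.0


if __name__ == "__main__":
    n = k = 2048  # illustrative sizes, not taken from a specific model
    for s in (1, 2, 5, 10):
        print(f"s={s}: {100 * batching_overhead(s, n, k):.2f}% more traffic")
```

With these sizes, going from $s = 1$ to $s = 5$ increases memory traffic by well under 1%, while the FLOPs grow 5x.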

This raises the question: how do we use the "free" compute that batching affords us?

Batching autoregressive decoding

The core issue for autoregressive decoding is that it's, well, autoregressive.

Every token we decode is conditioned on all of the previous tokens, so the Large Language Model by nature requires us to decode one token at a time. Given that, how could we possibly decode tokens in batches?

The core idea is to guess tokens however you like — then, check those guesses in a single batch, using just one forward pass.

How to guess

Let's talk about how to guess first. There are several possibilities. We will call these methods "approximation models", using the same terminology as the original paper:

  1. The original speculative decoding paper used a smaller Large Language Model to produce guesses. Both the small and the big models would come from the same model family, e.g., OPT-125M for the small model and OPT-65B for the big model.

    1. Later papers actually distill the small model from the big one, to ensure their predictions are aligned.
    2. Note that keeping the small model's KV cache up to date is quite tricky for this variant.
  2. Other work suggests that a really large n-gram model can stand in for a Large Language Model altogether. In a similar vein, you could use an n-gram model as the approximation model to provide guesses. These n-grams could be built from a general corpus, from the outputs generated so far, or even from the prompt.

  3. Honestly, any method of generating tokens works. In theory, you could even pick four random words.¹
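As one concrete example of the n-gram option, here is a minimal sketch of an approximation model whose n-grams come from the tokens seen so far (the prompt plus any output). The function names, the "most recent continuation wins" rule, and the default `n = 2` are illustrative assumptions, not the method of a specific paper.

```python
def build_ngram_table(tokens: list[str], n: int = 2) -> dict[tuple[str, ...], str]:
    """Map each (n - 1)-token context to the token that most recently followed it."""
    table: dict[tuple[str, ...], str] = {}
    for i in range(len(tokens) - n + 1):
        context = tuple(tokens[i : i + n - 1])
        table[context] = tokens[i + n - 1]  # later occurrences overwrite earlier ones
    return table


def guess_tokens(tokens: list[str], num_guesses: int, n: int = 2) -> list[str]:
    """Propose up to num_guesses tokens by chaining n-gram lookups.

    Stops early if the current context has never been seen before.
    """
    table = build_ngram_table(tokens, n)
    context = list(tokens)
    guesses: list[str] = []
    for _ in range(num_guesses):
        key = tuple(context[-(n - 1):]) if n > 1 else ()
        if key not in table:
            break
        guesses.append(table[key])
        context.append(table[key])
    return guesses


if __name__ == "__main__":
    history = "the brown fox jumps over the lazy dog . the".split()
    print(guess_tokens(history, 3))
```

For instance, after "the brown fox jumps over the lazy dog . the", `guess_tokens(history, 3)` proposes "lazy dog .", since "lazy" is the most recent word to follow "the".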

Whichever approximation method you pick, ask it to generate guesses. Say we ask our approximation model to generate 4 guesses. Perhaps it picks "the brown fox jumps". Next, we need to check: Do these guesses agree with what the original Large Language Model would have outputted?

Interlude: What Large Language Models predict

To understand how to check efficiently, we need to review what Large Language Models predict. Recall that transformers predict one output token for every input token, so if we input 4 tokens, the model outputs 4 tokens.

Let's continue our example from before. Say we input 4 tokens $x_0, x_1, x_2, x_3 =$ "the brown fox jumps". We then receive as output 4 tokens $y_1, y_2, y_3, y_4 =$ "brown crayon sleeps over". These outputs appear to make no sense, but that's normal: we're reading them in the wrong order. Instead of reading the outputs sequentially, we need to read each output right after the input prefix that produced it:

$$x_0 \rightarrow y_1, \quad x_0 x_1 \rightarrow y_2, \quad x_0 x_1 x_2 \rightarrow y_3, \quad x_0 x_1 x_2 x_3 \rightarrow y_4$$

Plugging in our example words, we would read the following.

  the → brown
  the brown → crayon
  the brown fox → sleeps
  the brown fox jumps → over

Now the outputs make sense. Said another way, our predictions are the next word for every prefix of our input. Our input was "the brown fox jumps", and the model tells us what it would have predicted after "the", after "the brown", after "the brown fox" and after "the brown fox jumps". All 4 of these next-word predictions, in one forward pass!
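The prefix reading can be mimicked with a toy stand-in for the model. `toy_forward` below is hypothetical: a lookup table invented to reproduce the example above, returning one prediction per input position, where position $i$ depends only on $x_0, \dots, x_i$, as in a causal transformer.

```python
# Hypothetical predictions, invented to match the example in the text.
NEXT_TOKEN = {
    ("the",): "brown",
    ("the", "brown"): "crayon",
    ("the", "brown", "fox"): "sleeps",
    ("the", "brown", "fox", "jumps"): "over",
}


def toy_forward(tokens: list[str]) -> list[str]:
    """Return [y_1, ..., y_len]: the prediction after each prefix x_0..x_i.

    A causal transformer produces all of these in one forward pass; this toy
    version simply looks each prefix up in NEXT_TOKEN.
    """
    return [NEXT_TOKEN[tuple(tokens[: i + 1])] for i in range(len(tokens))]


if __name__ == "__main__":
    inputs = ["the", "brown", "fox", "jumps"]
    for i, prediction in enumerate(toy_forward(inputs)):
        print(" ".join(inputs[: i + 1]), "->", prediction)
```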

How to check

Let's now put this fact to use. Recall that the inputs "the brown fox jumps" are just guesses. Do the guesses match what the Large Language Model would have predicted? Here are the guessed tokens and their corresponding outputs.

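$$\begin{array}{cccc} x_0 & x_1 & x_2 & x_3 \\ \textit{the} & \textit{brown} & \textit{fox} & \textit{jumps} \\ \underline{\text{brown}} & \underline{\text{crayon}} & \underline{\text{sleeps}} & \underline{\text{over}} \\ y_1 & y_2 & y_3 & y_4 \end{array}$$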

Now, let's use the output in the third row, underlined, to check the input guesses in the second row, italicized.

  1. After the first word $x_0$ the, we guessed $x_1$ brown. After the first word $x_0$ the, the model predicted $y_1$ brown. Our guess matches the actual prediction. Knowing this, we accept our guess, $x_1$.
  2. After the first two words $x_0, x_1$ the brown, we guessed $x_2$ fox. After the first two words $x_0, x_1$ the brown, the model predicted $y_2$ crayon. Our guess does not match the actual prediction. As a result, we reject our guess, $x_2$.
  3. We know the model would have predicted $y_2$ crayon next, so we add that to our list of tokens so far. This is key: this forward pass produced two tokens, one correct guess and one new prediction. Generally speaking, we keep all of the accepted guesses, plus the model's prediction right after the last accepted guess.
  4. There is no need to check any more guesses, because all of the later guesses ($x_3$) and predictions ($y_3, y_4$) are conditioned on the rejected guess $x_2$ fox.
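The checking procedure in the steps above fits in a few lines. This is a minimal sketch of greedy, exact-match verification, assuming $x_0$ is already-accepted context and that the predictions $y_1, \dots, y_{k+1}$ come from one forward pass over $x_0, \dots, x_k$. The function name is hypothetical, and speculative decoding with sampling uses a probabilistic acceptance rule instead.

```python
def accept_guesses(guesses: list[str], predictions: list[str]) -> list[str]:
    """Greedy verification of guessed tokens.

    guesses:     x_1..x_k, the approximation model's guesses (x_0 is context).
    predictions: y_1..y_{k+1}, the large model's prediction after each prefix
                 x_0..x_i, from a single forward pass over x_0..x_k.
    Returns the accepted guesses followed by the large model's own next token.
    """
    accepted: list[str] = []
    for guess, prediction in zip(guesses, predictions):
        if guess != prediction:
            # First mismatch: everything after it was conditioned on a bad guess.
            return accepted + [prediction]
        accepted.append(guess)
    # Every guess matched, so the prediction after the last guess comes for free.
    return accepted + [predictions[len(guesses)]]
```

On the example, `accept_guesses(["brown", "fox", "jumps"], ["brown", "crayon", "sleeps", "over"])` returns `["brown", "crayon"]`, giving "the brown crayon".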

Our complete output so far is now $x_0, x_1, y_2$, or "the brown crayon". This is pretty neat all in all: we can now check all of our guesses in a single forward pass! We then repeat this entire process — guess, obtain actual predictions using a single forward pass, then check how many guesses match actual predictions.
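Putting the pieces together, here is a minimal sketch of the full guess, check, repeat loop for greedy decoding. `speculative_decode`, `toy_forward`, and `toy_guess` are hypothetical; the toy "large model" simply continues a fixed sentence, and the toy approximation model is right except that it guesses "cat" where the sentence has "fox".

```python
from typing import Callable

Tokens = list[str]


def speculative_decode(
    prompt: Tokens,
    forward: Callable[[Tokens], Tokens],
    guess: Callable[[Tokens, int], Tokens],
    num_guesses: int,
    max_new_tokens: int,
) -> tuple[Tokens, int]:
    """Greedy speculative decoding: guess, check in one forward pass, repeat.

    forward(tokens) returns one prediction per input position: entry i is the
    large model's next token after the first i + 1 input tokens.
    guess(tokens, k) is any approximation model proposing up to k tokens.
    Returns (new tokens, number of large-model forward passes).
    """
    tokens = list(prompt)  # assumes a non-empty prompt
    passes = 0
    while len(tokens) - len(prompt) < max_new_tokens:
        guesses = guess(tokens, num_guesses)[:num_guesses]
        # One forward pass over the confirmed tokens plus all guesses. A real
        # implementation would reuse a KV cache and only feed the guesses.
        predictions = forward(tokens + guesses)
        passes += 1
        start = len(tokens) - 1  # prediction that follows the last confirmed token
        accepted: Tokens = []
        for j, g in enumerate(guesses):
            if g != predictions[start + j]:
                break  # later guesses depend on this rejected one
            accepted.append(g)
        # Keep the accepted guesses plus the large model's next prediction.
        tokens.extend(accepted + [predictions[start + len(accepted)]])
    return tokens[len(prompt) : len(prompt) + max_new_tokens], passes


# --- Toy models for illustration only (hypothetical, not real LLMs) ---
SENTENCE = "the quick brown fox jumps over the lazy dog .".split()


def toy_forward(tokens: Tokens) -> Tokens:
    """'Large model' that always continues SENTENCE based on position."""
    return [SENTENCE[(i + 1) % len(SENTENCE)] for i in range(len(tokens))]


def toy_guess(tokens: Tokens, k: int) -> Tokens:
    """'Approximation model' that is right except it says 'cat' for 'fox'."""
    upcoming = [SENTENCE[(len(tokens) + j) % len(SENTENCE)] for j in range(k)]
    return ["cat" if w == "fox" else w for w in upcoming]


if __name__ == "__main__":
    out, passes = speculative_decode(["the"], toy_forward, toy_guess, 3, 9)
    print(" ".join(out), f"({passes} forward passes for {len(out)} tokens)")
```

With these toy models, the loop produces the same 9 tokens as one-token-at-a-time decoding, using 3 forward passes instead of 9. If the approximation model were always wrong, each pass would still add the large model's own prediction, so the loop would take exactly one pass per token, matching normal autoregressive decoding.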

In the worst case, every guess is rejected, and each forward pass still yields one token, the model's own prediction, so this method degenerates to "normal" autoregressive decoding. This is how we leverage the "free" property of batching: if our guesses are correct, we decode multiple tokens at once; if our guesses are wrong, we fall back to normal decoding with little overhead, provided the approximation model is cheap to run. In this way, we get possibly faster inference for "free".

Takeaways

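Speculative decoding works in a fairly clever way, and we built up the intuition for it incrementally. Here are the takeaways from each step:

  1. Batching raises arithmetic intensity with negligible memory overhead, so it lets us use otherwise idle compute "for free".
  2. An approximation model, such as a smaller Large Language Model or an n-gram model, can cheaply guess several upcoming tokens.
  3. Because a Large Language Model predicts the next token for every prefix of its input, a single forward pass checks all of the guesses at once.
  4. We keep every accepted guess plus the model's next prediction, so each forward pass yields at least one token.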

Combined, these takeaways produce a method called speculative decoding, which lets us lower latency "for free", given that inference is heavily memory bound.





  1. Naturally, guessing random words is a bad idea. As you'll see in the next sections, you ideally want the approximation model to agree with the original Large Language Model. The more the two agree, the larger the speedup.