How speculative decoding works
As we discussed in Why Large Language Model inference is memory bound, we have three general tactics for improving inference latency:
- Directly address the memory bottleneck by lowering the amount of data we need to transfer. This includes techniques like quantization and pruning.
- Directly address the memory bottleneck by reordering the workload to better reuse data. This includes techniques like tiling and Flash Attention.
- Leverage under-utilized, extra compute cycles by increasing the total amount of computation.

In this post, we'll explore a technique in this last category, called "speculative decoding".
Let's hop right in.
Improving latency by batching
Let's see how batching affects the arithmetic intensity of matrix multiplication. Does batching change arithmetic intensity enough to turn a memory bound workload into a compute bound one?
In our previous discussions of tiled matrix multiplication, we deduced that arithmetic intensity is roughly $b$, the tile dimension. That analysis used $m \times n$ and $n \times k$ matrices, though the exact shapes didn't change the conclusion. Let's change this setup: say the first matrix is batched with batch size $s$, making the first matrix $s \times m \times n$. Now, our calculus changes slightly.
Before, for every tile of size $b \times b$ in the output, we would load $bn$ values from the first matrix and $nb$ values from the second. Now, we load $snb$ values from the first matrix for the same $nb$ values from the second, write $sb^2$ outputs, and execute $2nsb^2$ FLOPs. Our new arithmetic intensity is now
$$\frac{2nsb^2}{snb + nb + sb^2} = \frac{2nsb}{sn + n + sb} = \frac{2snb}{n(1+s) + sb}$$
As a sanity check, note that when $s = 1$, the expression reduces to the arithmetic intensity from Why Large Language Model inference is memory bound. There isn't a particularly good way to simplify this expression, so instead, let's plot it. To do so, we'll assume fixed values for $n = 2048$ and $b = 128$, then vary the batch size $s$.
Arithmetic intensity grows quickly as the batch size increases. The rate of growth is attenuated by the inner dimension $n$ and the tile size $b$.
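To make this concrete, here's a minimal Python sketch of the computation behind the plot. The function name is mine, and intensity is measured in FLOPs per value moved, matching the formula above.

```python
# Arithmetic intensity of the batched, tiled matrix multiplication above,
# measured in FLOPs per value moved.

def arithmetic_intensity(s: int, n: int = 2048, b: int = 128) -> float:
    flops = 2 * n * s * b * b                      # work for s output tiles of size b x b
    values_moved = s * n * b + n * b + s * b * b   # first matrix + second matrix + outputs
    return flops / values_moved

for s in (1, 2, 5, 10, 20):
    print(f"s = {s:2d}: {arithmetic_intensity(s):6.1f} FLOPs per value")
```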
As you can see, increasing the batch size strictly improves arithmetic intensity. Critically, increasing the batch size to just 5 or more already transforms inference from a memory bound to a compute bound workload.
In summary then, batching meaningfully increases arithmetic intensity; in this case, a batch size of 5 already improves it by a factor of roughly 1.6x. By contrast, since $m = 1$ and $k \gg 1$, the extra memory traffic from batching the first matrix is negligible compared to loading the second. All in all then, arithmetic intensity goes up significantly with negligible memory overhead, allowing us to increase computation "for free". This is our core insight: batching allows us to leverage unused compute "for free".
This raises the question: how do we use the "free" compute that batching affords us to get faster inference?
Batching autoregressive decoding
Let's say the user provides a prompt "the". The core issue for autoregressive decoding is that it's, well, autoregressive.
- Given "the", we predict "brown".
- Given "the brown", we predict "fox".
- Given "the brown fox", we predict "jumps" and so on and so forth.
As you can see above, every token we decode is conditioned on all of the previous tokens. The Large Language Model by nature requires us to decode one token at a time. Given that, how would we possibly decode tokens in batches?
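To make this sequential dependency concrete, here's a minimal sketch of vanilla autoregressive decoding; `model_next_token` is a hypothetical stand-in for one forward pass that returns a single next-token id.

```python
# Vanilla autoregressive decoding: each iteration depends on the token chosen in
# the previous iteration, so the loop cannot be naively parallelized.

def decode(prompt_ids: list[int], num_new_tokens: int, model_next_token) -> list[int]:
    tokens = list(prompt_ids)
    for _ in range(num_new_tokens):
        next_id = model_next_token(tokens)  # conditioned on ALL previous tokens
        tokens.append(next_id)
    return tokens
```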
The core idea is to guess tokens however you like — then, check those guesses in a single batch, using just one forward pass. Let me explain how that works.
How to guess
Let's talk about how to guess first. There are several possibilities; we will call these methods "approximation models", using the same terminology as the original paper:
- The original speculative decoding paper used a smaller Large Language Model to produce guesses. Both the small and the big models would come from the same family of models, e.g., OPT-125M for the small model and OPT-66B for the big model.
  - Later papers actually distill the small model from the big one, to ensure their predictions are aligned.
  - Note that keeping the small model's KV cache up to date is quite tricky for this variant.
- Other work suggests the Large Language Model as a whole can be replaced with a really large n-gram model. In a similar vein, you could use n-grams to provide guesses. These n-grams could be trained on a general corpus, on outputs generated so far, or even on the prompt (see the sketch after this list).
- Honestly, any method of generating tokens works. In theory, you could even pick four random words.¹
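As one concrete, and entirely hypothetical, example of the n-gram idea: here's a tiny "look it up in the context" guesser that proposes whatever tokens followed the most recent occurrence of the current trailing n-gram. The function name and interface are illustrative, not taken from any particular paper or library.

```python
# Guess draft tokens by finding the trailing n-gram earlier in the context and
# proposing the k tokens that followed it there.

def ngram_guess(context: list[int], n: int = 2, k: int = 4) -> list[int]:
    if len(context) < n:
        return []
    tail = context[-n:]
    # Search backwards for an earlier occurrence of the trailing n-gram.
    for start in range(len(context) - n - 1, -1, -1):
        if context[start:start + n] == tail:
            return context[start + n:start + n + k]  # tokens that followed it
    return []

# With a repeated bigram (5, 9), we guess the tokens that followed it last time:
print(ngram_guess([5, 9, 7, 3, 1, 5, 9], n=2, k=2))  # -> [7, 3]
```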
Whichever approximation method you pick, ask it to generate guesses. Say we ask our approximation model to continue our prompt "the"; perhaps it guesses "brown fox jumps", giving us the candidate sequence "the brown fox jumps". Next, we need to check: do these guesses agree with what the original Large Language Model would have outputted?
Interlude: What Large Language Models predict
To understand how to check efficiently, we need to review what Large Language Models predict. Recall that transformers predict one output token for every input token, so if we input 4 tokens, the model outputs 4 tokens.
Let's continue our example from before. Say we input 4 tokens $x_0, x_1, x_2, x_3 =$ "the brown fox jumps". We then receive as output 4 tokens $y_1, y_2, y_3, y_4 =$ "brown crayon sleeps over". These outputs appear to make no sense, but that's expected: we're simply reading them in the wrong order. Instead of reading them sequentially, we need to read each output after its corresponding input prefix, like
- $x_0, y_1$
- $x_0, x_1, y_2$
- $x_0, x_1, x_2, y_3$
- $x_0, x_1, x_2, x_3, y_4$
Plugging in our example words, we would read the following.
- the brown
- the brown crayon
- the brown fox sleeps
- the brown fox jumps over
Now the outputs make sense. Said another way, our predictions are the next word for every prefix of our input. Our input was "the brown fox jumps", and the model tells us what it would have predicted after "the", after "the brown", after "the brown fox" and after "the brown fox jumps". All 4 of these next-word predictions, in one forward pass!
In summary, using just a single forward pass, we can check if a batch of tokens is what the Large Language Model would've predicted autoregressively.
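Here's a hedged sketch of that idea using Hugging Face transformers, with GPT-2 standing in purely for illustration: one forward pass returns logits for every position, i.e., a next-token prediction for every prefix.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tok("the brown fox jumps", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits        # shape: (1, num_input_tokens, vocab_size)

# Greedy prediction after every prefix, i.e. y_1 ... y_4, from a single pass.
predicted_ids = logits.argmax(dim=-1)[0]
print(tok.convert_ids_to_tokens(predicted_ids.tolist()))
```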
How to check
Let's now use this fact above: Recall that our approximation model guessed "the brown fox jumps". Do the guesses match what the Large Language Model would've predicted? To check this, pass the guessed sequence as input to the Large Language Model, just like above.
I've rewritten the inputs and their corresponding outputs in tabular form: the four inputs and their labels in the first two rows, then the four outputs and their labels in the last two rows. Each output $y_i$ is placed directly below the guess $x_i$ it should be compared against.
| $x_0$ | $x_1$ | $x_2$ | $x_3$ | |
|---|---|---|---|---|
| the | brown | fox | jumps | |
| | brown | crayon | sleeps | over |
| | $y_1$ | $y_2$ | $y_3$ | $y_4$ |
Now, to check the input guesses in the second row, let's use the model's outputs in the third row.
- Focus on the second column.
  - After the first word $x_0$ "the", we guessed $x_1$ "brown".
  - After the first word $x_0$ "the", the model predicted $y_1$ "brown".
  - Our guess $x_1$ matches the actual prediction $y_1$. Knowing this, we accept our guess, $x_1$.
- Focus on the third column.
  - After the first two words $x_0, x_1$ "the brown", we guessed $x_2$ "fox".
  - After the first two words $x_0, x_1$ "the brown", the model predicted $y_2$ "crayon".
  - Our guess $x_2$ does not match the actual prediction $y_2$. As a result, we reject our guess, $x_2$.
- So far, we have only accepted one guess, so our sequence so far is $x_0, x_1$ "the brown". Given the third column, we know the model predicts $y_2$ "crayon" next. This is key: even though we only accepted one guess, this forward pass produced two tokens, one verified guess and one fresh prediction.
  - There is no need to check any more guesses, because all of our later guesses ($x_3$) and predictions ($y_3, y_4$) are conditioned on the rejected prefix $x_0, x_1, x_2$ "the brown fox", when the actual sequence so far is "the brown crayon".
Our complete output so far is now $x_0, x_1, y_2$, or "the brown crayon". This is pretty neat all in all: we can now check all of our guesses in a single forward pass! We then repeat this entire process — guess, obtain actual predictions using a single forward pass, then check how many guesses match actual predictions.
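Here's a minimal sketch of this greedy check, matching guesses against the model's argmax predictions rather than using the probabilistic rejection-sampling rule from the original paper. The function name and the lists-of-token-ids interface are mine.

```python
def check_guesses(guesses: list[int], predictions: list[int]) -> list[int]:
    """guesses[i] is x_{i+1}; predictions[i] is y_{i+1}, the model's prediction
    after the prefix x_0 .. x_i. Returns the tokens accepted this step."""
    accepted = []
    for guess, prediction in zip(guesses, predictions):
        if guess != prediction:
            accepted.append(prediction)  # the model's own token is always usable
            return accepted
        accepted.append(guess)           # guess verified, keep checking
    # Every guess matched, so the prediction after the last guess comes for free.
    accepted.append(predictions[len(guesses)])
    return accepted

# The worked example: "brown fox jumps" checked against "brown crayon sleeps over".
print(check_guesses([11, 22, 33], [11, 99, 44, 55]))  # -> [11, 99], i.e. "the brown crayon"
```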
In the worst case, all of our guesses are wrong and this method degenerates to "normal" autoregressive decoding, so we truly leverage the "free" property of batching: if our guesses are correct, we decode multiple tokens at once; if our guesses are wrong, we revert to normal decoding with negligible overhead. Either way, we get possibly faster inference for "free".
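Putting the pieces together, here's a hedged sketch of the overall loop. It reuses the hypothetical `ngram_guess` and `check_guesses` helpers from the earlier sketches, plus a stand-in `model_predictions` that runs one forward pass and returns the argmax token for every position.

```python
def speculative_decode(prompt_ids: list[int], num_new_tokens: int,
                       model_predictions, guesser) -> list[int]:
    tokens = list(prompt_ids)
    target_length = len(prompt_ids) + num_new_tokens
    while len(tokens) < target_length:
        guesses = guesser(tokens)                        # cheap draft tokens (possibly [])
        outputs = model_predictions(tokens + guesses)    # ONE forward pass over everything
        # outputs[i] is the model's prediction after the prefix ending at position i,
        # so the predictions relevant to our guesses start at the last accepted token.
        relevant = outputs[len(tokens) - 1:]
        tokens.extend(check_guesses(guesses, relevant))  # verified prefix + one free token
    return tokens[:target_length]
```

If the guesser returns nothing, or every guess is rejected, each iteration still appends the model's own prediction, so this loop never does worse than the vanilla `decode` sketch above.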
How much faster?
The number of accepted tokens, and the induced acceptance rate, determines the total speedup. Say we produce 4 guesses (formally, "draft tokens") and that 50% of these are accepted. The net effect is that two draft tokens are accepted on average per forward pass. Counting the model's own prediction, that's three tokens produced at a time, making this roughly a 3x speedup.
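As a back-of-the-envelope sketch of that arithmetic, ignoring the cost of running the approximation model and the sequential nature of acceptance:

```python
def rough_speedup(num_drafts: int, acceptance_rate: float) -> float:
    accepted = num_drafts * acceptance_rate  # draft tokens kept per forward pass, on average
    return accepted + 1                      # plus the model's own prediction

print(rough_speedup(4, 0.5))  # -> 3.0, i.e. roughly a 3x speedup
```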
Takeaways
Speculative decoding works in a fairly clever way, and the intuition for it comes from a number of takeaways we built up incrementally.
- Batching allows us to leverage unused compute "for free", because larger batch sizes increase arithmetic intensity fairly quickly.
- There are several ways to produce guesses, using what we call approximation models. These can be a smaller Large Language Model, n-grams, or really any way of generating words — maybe even random guesses.
- Using just a single forward pass, we can check an entire batch of guesses, using the fact that predictions are all autoregressively defined. This ensures that tokens produced by speculative decoding are identical to the tokens produced by the original Large Language Model.
Combined, these takeaways produce a method called speculative decoding, which lets us reduce latency "for free," given that inference is heavily memory bound.
1. Naturally, guessing random words is a bad idea. As you'll see in the next sections, you ideally want the approximation model to agree with the original Large Language Model. The more the two agree, the greater the speedup.