Introduction
Over the past few years, large language models (LLMs) – most recently ChatGPT – have received lots of (well-deserved) press. Though they have their shortcomings, they are able to compose shockingly cogent prose, and the quality appears to increase the bigger the models themselves become.
One aspect that I think is often overlooked by much of the public is the (computational, which implies financial) cost of actually turning the model’s crank to produce text. Soon after ChatGPT was released, it was estimated that at 1 million users it was costing OpenAI around 3 million dollars per month in cloud compute costs. With 100 million users (assuming nothing else has changed), this would cost 300 million dollars per month, or 3.6 billion dollars per year!
A theme we’ve observed over and over in technology during the past few decades is that when a disruptive, but costly, technological advance emerges, the cost of producing – or in this case, serving – that technology typically declines as the incentive to profit from its capabilities grows.
In this post, I want to describe a relatively simple and intuitive technique from a paper by DeepMind that might be a first step in the direction of bringing down the cost of operating the types of large language models that I believe will become ubiquitous in the years and decades to come.
(As a note, below I will use the word “expensive” a lot. This word refers to computational expense, but computational expense is directly relatable to financial expense, so you can think of it that way too if you’d like.)
Background
At a very high level, many large language models (such as GPT and friends) produce text using autoregressive sampling, which is a fancy term for using the previously generated text (hence autoregressive) to produce a probability distribution over possible next words, and then drawing the next word at random from that distribution (hence sampling). To understand what this means, suppose your vocabulary has three words: “apple,” “banana,” and “carrot.” A distribution over these three words assigns a probability to each word (the probabilities have to add up to 1), which tells you how likely you are to obtain that word if you sample from the distribution. (There are infinitely many possible distributions you could choose. In practice, a distribution only needs to be reasonable, not exactly correct.)
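To make this concrete, here is a small Python sketch of a distribution over the three-word vocabulary and of sampling from it. The probabilities (0.5, 0.3, 0.2) and the helper name `sample_word` are made up for illustration; the only point is that sampling picks each word in proportion to its probability.

```python
import random

# Hypothetical probabilities for the three-word vocabulary in the text.
vocab_dist = {"apple": 0.5, "banana": 0.3, "carrot": 0.2}

# A valid distribution has nonnegative probabilities that sum to 1.
assert all(prob >= 0 for prob in vocab_dist.values())
assert abs(sum(vocab_dist.values()) - 1.0) < 1e-9

def sample_word(dist, rng):
    """Hypothetical helper: draw one word in proportion to its probability."""
    words = list(dist)
    weights = list(dist.values())
    return rng.choices(words, weights=weights, k=1)[0]

rng = random.Random(0)  # seeded so the draws are reproducible
draws = [sample_word(vocab_dist, rng) for _ in range(10_000)]
for word in vocab_dist:
    print(word, draws.count(word) / len(draws))  # close to 0.5, 0.3, 0.2
```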
The distribution over your vocabulary words is determined by the model at each step. In some sense, the model produces a distribution that makes sense given the text you’ve already generated (and/or a prompt that you wrote). If you had already generated the partial sentence “I will wear a raincoat because it is going to,” a good model would produce a distribution over your vocabulary that places a very high probability on a word like “rain,” and a very low probability on a word like “spinach” (which doesn’t fit the context).
The takeaway from this can be summarized as follows:
1. These models are very large, so each step (word) is expensive to compute.
2. Because the samples have to be produced sequentially, they require many LLM steps.
Put simply, many steps x high cost = very high cost!
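Here is a rough sketch of autoregressive sampling that makes this cost structure visible. `toy_model` is a hypothetical stand-in for an LLM (it just favors “rain” after “to”), and `autoregressive_sample` is an illustrative helper, not any library’s API. The thing to notice is that every generated word needs its own model call, and each call has to wait for the previous one.

```python
import random

VOCAB = ["it", "is", "going", "to", "rain", "spinach"]

def toy_model(context):
    """Hypothetical stand-in for an LLM: maps a context to a next-word distribution.

    A real model would run a large neural network here; this toy just favors
    "rain" right after "to" and is uniform otherwise.
    """
    if context and context[-1] == "to":
        return {"rain": 0.97, "spinach": 0.01, "it": 0.01, "is": 0.01}
    return {word: 1 / len(VOCAB) for word in VOCAB}

def autoregressive_sample(prompt, num_new_words, model, rng):
    """Hypothetical helper: generate words one at a time, feeding each back in."""
    sequence = list(prompt)
    model_calls = 0
    for _ in range(num_new_words):
        dist = model(sequence)  # one (expensive) model call per generated word
        model_calls += 1
        words, weights = zip(*dist.items())
        sequence.append(rng.choices(words, weights=weights, k=1)[0])
    return sequence, model_calls

rng = random.Random(0)
text, calls = autoregressive_sample(["it", "is", "going", "to"], 5, toy_model, rng)
print(" ".join(text))
print("model calls:", calls)  # 5: one call per new word
```

With a real LLM, each of those calls is a full pass through a very large network, which is exactly the “many steps x high cost” problem.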
The New Idea
DeepMind attempts to tackle point (2) from the previous section by reducing the number of inference steps required of the very large (and very expensive) model while maintaining the high quality of the generated tokens; it sounds like a free lunch! How do they do this?
The basic idea is that given some number of previous words, we:
- Use a smaller, less expensive model to generate a candidate sequence of a certain length.
- Use the big model to score the words generated by the small model. (The scores here can be thought of as measures of approval of the draft tokens by the big model.)
- Use the word scores to decide how much of the sequence to use.
- Rinse and repeat until the sequence is of the desired length.
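Here is a minimal end-to-end sketch of that loop in Python, assuming two toy models over a four-token vocabulary. `target_model`, `draft_model`, `speculative_sample`, and the draft length `k` are all hypothetical names chosen for illustration. The acceptance rule and fallback distribution it uses are the ones described in the Speculative Sampling section. Following my reading of the paper, it also takes one extra token from the big model when the whole draft is accepted; dropping that step would still give correct samples, just slightly fewer tokens per round.

```python
import random

VOCAB_SIZE = 4  # hypothetical tiny vocabulary: tokens are the integers 0..3

def toy_distribution(context, peak):
    """Next-token probabilities that favor (last token + 1); `peak` sets how strongly."""
    last = context[-1] if context else 0
    weights = [1.0] * VOCAB_SIZE
    weights[(last + 1) % VOCAB_SIZE] += peak
    total = sum(weights)
    return [w / total for w in weights]

def target_model(context):
    """Hypothetical expensive model q."""
    return toy_distribution(context, peak=4.0)

def draft_model(context):
    """Hypothetical cheap model p: a blurrier copy of q."""
    return toy_distribution(context, peak=2.0)

def sample(dist, rng):
    """Draw a token index; random.choices rescales weights that don't sum to 1."""
    return rng.choices(range(len(dist)), weights=dist, k=1)[0]

def speculative_sample(prompt, target_len, k, rng):
    seq, rounds = list(prompt), 0
    while len(seq) < target_len:
        rounds += 1
        # Draft k tokens one at a time with the small model.
        draft, draft_dists = [], []
        for _ in range(k):
            p = draft_model(seq + draft)
            draft_dists.append(p)
            draft.append(sample(p, rng))
        # Score the draft with the big model: k + 1 next-token distributions.
        # (A real system gets all of them from one parallel forward pass.)
        target_dists = [target_model(seq + draft[:i]) for i in range(k + 1)]
        # Decide how much of the draft to keep.
        for i, token in enumerate(draft):
            q, p = target_dists[i], draft_dists[i]
            if rng.random() < min(1.0, q[token] / p[token]):
                seq.append(token)  # accepted: move on to the next draft token
            else:
                residual = [max(0.0, qi - pi) for qi, pi in zip(q, p)]
                seq.append(sample(residual, rng))  # rejected: resample, then redraft
                break
        else:
            seq.append(sample(target_dists[k], rng))  # whole draft accepted: bonus token
    return seq[:target_len], rounds  # drop any overshoot past the desired length

rng = random.Random(0)
tokens, rounds = speculative_sample([0], target_len=20, k=4, rng=rng)
print(tokens)
print("big-model rounds:", rounds, "for", len(tokens) - 1, "new tokens")
```

Each round costs one (parallel) pass of the big model but can add up to `k + 1` tokens, so `rounds` comes out smaller than the number of new tokens whenever the drafts are good.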
If you stop reading here, you’ve learned the important idea. Lately, I’ve been trying to keep posts very high level, but because the importance-to-complexity ratio of this idea is so high, I’m going to break pattern and go into more technical detail in the sections below.
Speculative Sampling
We will now discuss the algorithm in more detail. The below steps are carried out until the sequence is of the desired length.
Step 1: Generate and Score a Draft
The draft is generated by a smaller, less expensive model (the draft model) one token at a time, so draft generation must be sequential. The scoring, which requires appealing to the large model, can happen in parallel, and this is where the algorithm’s speed-up comes from. In this context, scoring a token means computing the probability, under the large model, of that token occurring given the current sequence and the draft tokens that precede it. Because a single pass of the large model over the current sequence plus the draft produces these probabilities for every draft position at once, the expensive model is called once per draft rather than once per token.
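To illustrate where the savings come from, the sketch below counts model calls during one draft-and-score step. The models are hypothetical toys, and `target_model_all_positions` only simulates what a transformer forward pass over the prefix plus the draft would return: a next-token distribution at every position at once.

```python
import random

VOCAB_SIZE = 4  # hypothetical tiny vocabulary: tokens are the integers 0..3
calls = {"draft": 0, "target": 0}

def toy_distribution(context, peak):
    """Next-token probabilities that favor (last token + 1); `peak` sets how strongly."""
    last = context[-1] if context else 0
    weights = [1.0] * VOCAB_SIZE
    weights[(last + 1) % VOCAB_SIZE] += peak
    total = sum(weights)
    return [w / total for w in weights]

def draft_model(context):
    """Hypothetical cheap model p: one call per drafted token."""
    calls["draft"] += 1
    return toy_distribution(context, peak=2.0)

def target_model_all_positions(prefix, draft):
    """Hypothetical expensive model q, scoring all draft positions in one call.

    Returns len(draft) + 1 distributions: entry i is q(. | prefix + draft[:i]).
    The loop only simulates the output of a single parallel forward pass.
    """
    calls["target"] += 1
    return [toy_distribution(prefix + draft[:i], peak=4.0) for i in range(len(draft) + 1)]

def draft_and_score(prefix, k, rng):
    """Hypothetical helper: draft k tokens, then score them with one big-model call."""
    draft, draft_dists = [], []
    for _ in range(k):  # sequential: each draft token depends on the ones before it
        p = draft_model(prefix + draft)
        draft_dists.append(p)
        draft.append(rng.choices(range(VOCAB_SIZE), weights=p, k=1)[0])
    target_dists = target_model_all_positions(prefix, draft)  # one expensive call
    # Score pairs (p(x~_t | ...), q(x~_t | ...)) for each draft token.
    scores = [(draft_dists[t][tok], target_dists[t][tok]) for t, tok in enumerate(draft)]
    return draft, scores, target_dists

rng = random.Random(0)
draft, scores, target_dists = draft_and_score([0], k=4, rng=rng)
print("draft:", draft)
print("draft-model calls:", calls["draft"])    # 4
print("target-model calls:", calls["target"])  # 1
```

For a draft of length `k`, the cheap model is called `k` times in sequence while the expensive model is called once, and that one call also yields an extra distribution for the position after the last draft token.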
Step 2: Deciding How Much of the Sequence to Accept
Next, the algorithm has to figure out how much of the draft to accept. It does this by accepting each successive token produced by the draft model with some probability (which depends on the tokens that have already been accepted). Once we decide not to accept a token, we sample a replacement token from some other distribution (we will think through a good one to use below) and start fresh with a new draft. In my opinion, choosing the right acceptance probability and the right alternative distribution is where the algorithm’s cleverness is really on display. Concretely, there are two questions we seek to answer in this section:
1. What probability $r$ should we use to decide whether or not to accept the next draft token?
2. What alternative distribution should we use if the $t$th token of the draft is not accepted?
Before addressing these questions, we should state this algorithm’s overall objective a little bit more precisely: We want the speculatively sampled sequence to come from the same distribution as the sequence we would get if we autoregressively sampled a sequence from the large expensive model.
Another item to clarify is that we can, in some sense, discuss models and distributions interchangeably. Here, a model is a way of taking a sequence and producing probabilities that each token in the vocabulary is the next one in the sequence. We can therefore talk about models the same way we talk about distributions. To this end, let $q(x \mid x_1, x_2, \dots, x_i)$ be the expensive model, and $p(x \mid x_1, x_2, \dots, x_i)$ be the draft model. As in the paper, we will also refer to the $t$th draft token as $\tilde x_t$.
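I will first state the answers to questions (1) and (2), and then I will show that they work. The probability $r$ that we use to decide whether to accept the $t$th draft token is given by the expression

$$r = \min\left(1, \frac{q(\tilde x_t \mid x_1, \dots, x_{n+t-1})}{p(\tilde x_t \mid x_1, \dots, x_{n+t-1})}\right),$$

and the rejection distribution is

$$\big(q(x \mid x_1, \dots, x_{n+t-1}) - p(x \mid x_1, \dots, x_{n+t-1})\big)_+ = \frac{\max\big(0,\ q(x \mid x_1, \dots, x_{n+t-1}) - p(x \mid x_1, \dots, x_{n+t-1})\big)}{\sum_{\hat x} \max\big(0,\ q(\hat x \mid x_1, \dots, x_{n+t-1}) - p(\hat x \mid x_1, \dots, x_{n+t-1})\big)}. \tag{1}$$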
These probabilities and distributions depend on the combined initial sequence and the draft tokens accepted so far (here $x_1, \dots, x_n$ is the initial sequence and $x_{n+1}, \dots, x_{n+t-1}$ are the accepted draft tokens). Also, note that if $q(\tilde x_t \mid x_1, \dots, x_{n+t-1}) > p(\tilde x_t \mid x_1, \dots, x_{n+t-1})$, then $r = 1$ and we automatically accept the $t$th draft token. Intuitively, this checks out: if the draft model produces a token that, given the prior tokens, is more likely under the large model than under the draft model, of course we should keep it!
If instead we have $q(\tilde x_t \mid x_1, \dots, x_{n+t-1}) \leq p(\tilde x_t \mid x_1, \dots, x_{n+t-1})$, then we accept $\tilde x_t$ with probability $q/p$, which is larger the closer $q$’s score is to $p$’s. If $p$ gives $\tilde x_t$ a score of 0.36 and $q$ gives it a score of 0.12, for example, then we will accept $\tilde x_t$ with probability 0.12 / 0.36 = 1/3. Alternatively, if $p$ gives a score of 0.36 and $q$ gives a score of 0.0001, we accept the token with probability of only about 0.0003, so we will almost certainly reject it, because the target and draft models really disagree about whether $\tilde x_t$ would make a good next token.
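The acceptance rule fits in a couple of lines. This is a sketch with hypothetical helper names; `q_score` and `p_score` stand for $q(\tilde x_t \mid \dots)$ and $p(\tilde x_t \mid \dots)$, and the numbers reproduce the examples above.

```python
import random

def accept_probability(q_score, p_score):
    """Hypothetical helper: min(1, q(x~_t | ...) / p(x~_t | ...)) for one draft token."""
    return min(1.0, q_score / p_score)

def accept_draft_token(q_score, p_score, rng):
    """Hypothetical helper: keep the draft token with probability min(1, q/p)."""
    return rng.random() < accept_probability(q_score, p_score)

print(accept_probability(0.12, 0.36))    # roughly 0.333, i.e. 1/3
print(accept_probability(0.0001, 0.36))  # roughly 0.00028: almost always rejected
print(accept_probability(0.50, 0.36))    # 1.0: q > p, so always accepted

# Empirical check: the fraction of kept tokens approaches 1/3.
rng = random.Random(0)
trials = 100_000
kept = sum(accept_draft_token(0.12, 0.36, rng) for _ in range(trials))
print(kept / trials)
```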
(Recall from earlier that we computed the $q$ scores for the draft tokens in parallel during Step 1.)
If we accept $\tilde x_t$, then we take another step and consider whether or not to accept $\tilde x_{t+1}$. If, on the other hand, we reject $\tilde x_t$, we sample a replacement from the complicated-looking distribution we spelled out in equation (1). What we want is a distribution that sensibly re-weights the large model’s probabilities over the possible tokens. In our case, $q - p$ (in the numerator) will produce negative “probabilities” for tokens where $q < p$, and positive values where $q > p$. This makes sense: if $\tilde x_t$ is rejected, we want to favor sampling tokens to which $q$ assigns higher scores than $p$ does. We have two problems, though:
1. Probabilities have to be nonnegative.
2. Probabilities have to sum to 1.
But these are actually no problem at all! To solve (1), we modify $q-p$ to $\max(0, q-p)$, and to solve (2), we use the standard normalization trick: dividing by the sum! (For example, if I wanted to make the list [1, 2, 3, 4, 5] into a probability distribution, I would divide each element by the sum, 15, to obtain [1/15, 2/15, 3/15, 4/15, 5/15].) Once we make both of those modifications, we obtain equation (1), which we shorten to the expression on the left side of the equals sign.
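Here is a short sketch of the normalization trick and of the resulting rejection distribution. The helper names are hypothetical, and the four-token distributions `q` and `p` are made-up values chosen so the clipping is easy to see.

```python
def normalize(values):
    """Hypothetical helper: divide each entry by the total so the entries sum to 1."""
    total = sum(values)
    return [v / total for v in values]

def residual_distribution(q, p):
    """Hypothetical helper for (q - p)_+: clip q - p at zero, then renormalize.

    Only needed after a rejection, which can only happen when q < p for the
    rejected token; since q and p both sum to 1, q > p somewhere else, so the
    total being divided by is positive.
    """
    clipped = [max(0.0, qi - pi) for qi, pi in zip(q, p)]
    return normalize(clipped)

# The example from the text: [1/15, 2/15, 3/15, 4/15, 5/15], as floats.
print(normalize([1, 2, 3, 4, 5]))

# Made-up next-token distributions over a four-token vocabulary.
q = [0.5, 0.2, 0.2, 0.1]  # expensive model
p = [0.3, 0.4, 0.1, 0.2]  # draft model
# q - p is roughly [0.2, -0.2, 0.1, -0.1]; clipping and renormalizing leaves
# all the mass on tokens 0 and 2, roughly [2/3, 0, 1/3, 0].
print(residual_distribution(q, p))
```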
So far, we have resorted to intuition to motivate our choices, but it turns out that the two choices fit together like elegant mathematical puzzle pieces. How this happens goes back to our objective, which is to come up with a sample that looks as though it was obtained strictly using the target (expensive) model. Let’s show that our strategy accomplishes exactly this.
Proving that we recover the target distribution
Suppose we have two discrete distributions $a$ (the target, playing the role of $q$) and $b$ (the draft, playing the role of $p$), and a draft sample $x' \sim b$ ($\sim$ means “sampled from”). Let $X$ be a random variable representing the token produced by speculative sampling. If $X$ ends up taking on a specific value $x$, there are two possible ways it could have happened:
1. We accepted $x'$, in which case $x' = x$.
2. We rejected $x'$, in which case $x \sim (a - b)_+$.
Outcome 1
The probability of outcome (1) is the probability that the draft sample $x'$ is accepted given that it takes the particular value $x$. We have to multiply this by the probability that the draft distribution assigns to the event that $x'$ takes on that value, so we have

$$P(X = x,\ x'\ \text{accepted}) = P(x' = x)\, P(x'\ \text{accepted} \mid x' = x).$$
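The probability of sampling the value $x$ from the distribution $b$ is simply $b(x)$. The probability of accepting it is $\min(1, a(x) / b(x))$ (from the algorithm specification). We thus have

$$P(X = x,\ x'\ \text{accepted}) = b(x) \min\left(1, \frac{a(x)}{b(x)}\right) = \min\big(a(x), b(x)\big).$$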
Outcome 2
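On the other hand, if $x'$ is rejected, then the probability that $X$ takes the value $x$ is the probability of sampling $x$ from $(a - b)_+$. By our definition of that distribution, we have

$$P(X = x \mid x'\ \text{rejected}) = \big(a(x) - b(x)\big)_+ = \frac{\max\big(0,\ a(x) - b(x)\big)}{\sum_{\hat x} \max\big(0,\ a(\hat x) - b(\hat x)\big)}.$$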
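We need to weight this outcome by the probability that $x'$ is rejected, which is given by

$$\begin{aligned}
P(x'\ \text{rejected}) &= 1 - P(x'\ \text{accepted}) \\
&= 1 - \sum_{\hat x} P(X = \hat x,\ x'\ \text{accepted}) \\
&= 1 - \sum_{\hat x} \min\big(a(\hat x), b(\hat x)\big) \\
&= \sum_{\hat x} \Big(a(\hat x) - \min\big(a(\hat x), b(\hat x)\big)\Big) \\
&= \sum_{\hat x} \max\big(0,\ a(\hat x) - b(\hat x)\big).
\end{aligned}$$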
In the above sequence of equalities, the second uses the fact that to get the probability of an event, we can sum its joint probability with $X = \hat x$ over all possible values of $\hat x$. The third equality reuses our computation from Outcome 1. The fourth uses the fact that the 1 outside the summation can be broken into probabilities $a(\hat x)$ for all possible values of $\hat x$, since probabilities must sum to 1. Finally, the last equality follows when you flip $-\min(a(\hat x), b(\hat x))$ to $\max(-a(\hat x), -b(\hat x))$ and then add $a(\hat x)$ to both of the arguments to $\max$.
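Now that we’ve worked out all of the details, does the last expression look familiar? It is the denominator of $P(X = x \mid x'\ \text{rejected})$! Multiplying our two probabilities together, we have

$$P(X = x,\ x'\ \text{rejected}) = P(x'\ \text{rejected})\, P(X = x \mid x'\ \text{rejected}) = \max\big(0,\ a(x) - b(x)\big).$$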
Putting Them Together
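Now that we’ve computed probabilities for both options, we note that the two possibilities are mutually exclusive and exhaustive ways that $X$ can take the value $x$. Thus, the probability $P(X = x)$ is given by

$$P(X = x) = P(X = x,\ x'\ \text{accepted}) + P(X = x,\ x'\ \text{rejected}) = \min\big(a(x), b(x)\big) + \max\big(0,\ a(x) - b(x)\big).$$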
Now, if $a(x) > b(x)$, then the first term is $b(x)$ and the second term is $a(x) - b(x)$. Adding these together, we get $a(x)$. If $a(x) \leq b(x)$, then the first term is $a(x)$ and the second term is 0, so again the sum is $a(x)$. Thus, speculative sampling recovers the target distribution $a(x)$. Applying this argument at every position, the rejection sampling technique we’ve devised produces sequences of tokens that are, in theory, indistinguishable from sequences sampled directly from the very expensive target model!
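If you’d rather check this numerically than trust the algebra, the sketch below computes $P(X = x) = \min(a(x), b(x)) + P(x'\ \text{rejected}) \cdot (a - b)_+(x)$ exactly with Python’s `fractions` module, first for a made-up pair of distributions and then for many random ones. The helpers `speculative_output_distribution` and `random_distribution` are hypothetical; the first also asserts the identity from Outcome 2 that the rejection probability equals the normalizer of $(a - b)_+$.

```python
import random
from fractions import Fraction

def speculative_output_distribution(a, b):
    """Hypothetical helper: exact P(X = x) after one step of speculative sampling.

    a: target distribution, b: draft distribution (lists of Fractions).
    """
    # Outcome 1: P(X = x, x' accepted) = min(a(x), b(x)).
    accepted = [min(ai, bi) for ai, bi in zip(a, b)]
    p_reject = 1 - sum(accepted)
    # Unnormalized (a - b)_+ and its normalizer.
    positive_part = [max(Fraction(0), ai - bi) for ai, bi in zip(a, b)]
    normalizer = sum(positive_part)
    assert p_reject == normalizer  # the identity derived under Outcome 2
    if normalizer == 0:  # a == b: the draft token is never rejected
        return accepted
    # Outcome 2: P(x' rejected) * (a - b)_+(x); then add both outcomes.
    return [acc + p_reject * pp / normalizer for acc, pp in zip(accepted, positive_part)]

def random_distribution(size, rng):
    """Hypothetical helper: a random distribution with exact rational entries."""
    weights = [Fraction(rng.randint(1, 9)) for _ in range(size)]
    total = sum(weights)
    return [w / total for w in weights]

a = [Fraction(1, 2), Fraction(1, 5), Fraction(1, 5), Fraction(1, 10)]  # target
b = [Fraction(3, 10), Fraction(2, 5), Fraction(1, 10), Fraction(1, 5)]  # draft
print(speculative_output_distribution(a, b) == a)  # True

rng = random.Random(0)
trials = [(random_distribution(6, rng), random_distribution(6, rng)) for _ in range(1000)]
print(all(speculative_output_distribution(t, d) == t for t, d in trials))  # True
```

Using exact fractions rather than floats means the check is an equality, not an approximation, so it mirrors the algebra above term by term.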
Conclusion
While this technique is just one step towards making LLMs more efficient, it highlights the potential return on further innovation in the space of faster LLM inference. As the number of LLM applications continues to explode, we can expect even more creative solutions to emerge, hopefully making these powerful tools more accessible and affordable for everyone.
If not speculative sampling itself, methods in the same spirit will become more necessary and important as we continue to push the boundaries of size and scale in generative models. I thought this technique was worth illuminating because of its simple, yet powerful, theoretically grounded choices. In many deep learning applications, systems often seem like quasi-magical feats of engineering whose designers don’t even always know why they work as well as they do. In reading DeepMind’s speculative sampling paper, I found the technique’s simplicity and mathematical rigor refreshing.