Speculative Decoding for Transformers

tl;dr: I summarized the speculative decoding method for transformers and elaborated on why it is correct.

Speculative Decoding

I recently came across the paper Fast Inference from Transformers via Speculative Decoding. It borrows the idea of speculative execution from CPUs and proposes a new sampling algorithm, called speculative decoding, which lets several tokens be produced per serial run of the large model and thus speeds up inference.

The motivation for this technique is that, traditionally, sampling N tokens from an LLM requires running the LLM serially N times, because each new token depends on all the previous ones. This is very inefficient. Speculative decoding can greatly reduce this cost: the paper reports a 2X to 3X speedup on T5-XXL with identical outputs. Moreover, it is compatible with a wide range of existing techniques (such as top-p, top-k, beam search, and so on).

Main Idea

  • Use a smaller, faster model to approximate the output of the larger target model. Specifically, a few tokens (γ in the paper) are sampled from the smaller model, and each one is then accepted or rejected depending on how well it agrees with the target model's distribution. The first rejected token is replaced by a sample from an adjusted distribution.
  • This pays off when spare computational resources are available: checking the smaller model's guesses requires running the larger model on every prefix formed by those guesses. These runs are independent, so they can be done in parallel (for a transformer, in a single forward pass).
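The two points above can be sketched as a single draft-then-verify round. What follows is a minimal pure-Python sketch under my own assumptions, not the paper's code: the "models" are hypothetical functions `target(prefix)` and `draft(prefix)` that return a probability list over a tiny vocabulary, `gamma` plays the role of the paper's γ, and the control flow follows my reading of the paper's Algorithm 1.

```python
import random


def sample(probs, rng):
    """Draw one token index from a (possibly unnormalized) weight list."""
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


def speculative_step(prefix, target, draft, gamma, rng):
    """One draft-then-verify round, loosely following Algorithm 1 of the paper.

    `target` and `draft` are hypothetical stand-ins for the large and small
    models: each maps a token tuple to a probability list over the vocabulary.
    Returns the 1 to gamma + 1 new tokens produced by this round.
    """
    # 1. Run the cheap draft model serially to propose gamma tokens.
    guesses, q_dists = [], []
    for _ in range(gamma):
        q = draft(tuple(prefix) + tuple(guesses))
        q_dists.append(q)
        guesses.append(sample(q, rng))

    # 2. Score all gamma + 1 prefixes with the target model. In a real
    #    transformer this would be one parallel forward pass.
    p_dists = [target(tuple(prefix) + tuple(guesses[:i])) for i in range(gamma + 1)]

    # 3. Accept guess i with probability min(1, p(x_i) / q(x_i)) and stop at
    #    the first rejection. q(x_i) > 0 because x_i was drawn from q.
    n = 0
    while n < gamma:
        x = guesses[n]
        if rng.random() < min(1.0, p_dists[n][x] / q_dists[n][x]):
            n += 1
        else:
            break

    # 4. Emit one more token: from norm(max(0, p - q)) after a rejection,
    #    or straight from the target's next-token distribution if all passed.
    if n < gamma:
        residual = [max(0.0, pi - qi) for pi, qi in zip(p_dists[n], q_dists[n])]
        last = sample(residual, rng)  # rng.choices normalizes the weights itself
    else:
        last = sample(p_dists[gamma], rng)
    return guesses[:n] + [last]


# Toy usage: both hypothetical "models" ignore the prefix (3-token vocabulary).
def toy_target(prefix):
    return [0.6, 0.4, 0.0]


def toy_draft(prefix):
    return [0.2, 0.3, 0.5]


print(speculative_step((), toy_target, toy_draft, gamma=3, rng=random.Random(0)))
```

With these toy distributions the target never assigns mass to token 2 while the draft proposes it half the time, so every drafted token 2 is rejected and replaced. A real implementation would batch step 2 into one call of the large model and work with logits rather than Python lists.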

Correctness

The correctness argument is presented on page 3, in Section 3.2 “Calculating α”, and in Appendix A.1. The following is an excerpt from the correctness proof:

We will now show that for any distributions $p(x)$ and $q(x)$, the tokens sampled via speculative sampling from $p(x)$ and $q(x)$ are distributed identically to those sampled from $p(x)$ alone. Let $\beta$ be the acceptance probability (Definition 3.1).

Note that since

$$p'(x) = \text{norm}(\max(0, p(x) - q(x))) = \frac{p(x) - \min(q(x), p(x))}{\sum_{x'} (p(x') - \min(q(x'), p(x')))} = \frac{p(x) - \min(q(x), p(x))}{1 - \beta},$$

the normalizing constant for the adjusted distribution $p'(x)$ is $1 - \beta$, where the last equation follows directly from Lemma 3.3 and Theorem 3.5.
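To see the normalizing constant with concrete numbers, here is a small exact check using Python's `fractions` module. The distributions `p` and `q` over three tokens are made up for illustration; the check confirms that $\max(0, p(x) - q(x))$ equals $p(x) - \min(q(x), p(x))$ pointwise and that its total mass is $1 - \beta$.

```python
from fractions import Fraction as F

# Hypothetical target p and draft q over a 3-token vocabulary.
p = [F(3, 5), F(2, 5), F(0)]
q = [F(1, 5), F(3, 10), F(1, 2)]

beta = sum(min(pi, qi) for pi, qi in zip(p, q))          # acceptance probability
residual = [max(F(0), pi - qi) for pi, qi in zip(p, q)]  # max(0, p(x) - q(x))
normalizer = sum(residual)
p_adj = [r / normalizer for r in residual]               # p'(x) = norm(max(0, p - q))

# max(0, p - q) equals p - min(q, p) pointwise, and its total mass is 1 - beta.
assert residual == [pi - min(qi, pi) for pi, qi in zip(p, q)]
assert normalizer == 1 - beta
print(beta, normalizer, p_adj)  # 1/2 1/2 [Fraction(4, 5), Fraction(1, 5), Fraction(0, 1)]
```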

Now:

$$P(x = x') = P(\text{guess accepted}, x = x') + P(\text{guess rejected}, x = x')$$

Where:

$$P(\text{guess accepted}, x = x') = q(x') \min\left(1, \frac{p(x')}{q(x')}\right) = \min(q(x'), p(x'))$$

And:

$$P(\text{guess rejected}, x = x') = (1 - \beta) p'(x') = p(x') - \min(q(x'), p(x'))$$

Overall:

$$P(x = x') = \min(q(x'), p(x')) + p(x') - \min(p(x'), q(x')) = p(x')$$

As desired. □
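The same bookkeeping can be checked numerically. This sketch uses hypothetical three-token distributions and exact fractions to compute both terms of the decomposition for every $x'$ and confirm that they add up to $p(x')$.

```python
from fractions import Fraction as F

# Hypothetical target p and draft q over a 3-token vocabulary (all q > 0).
p = [F(3, 5), F(2, 5), F(0)]
q = [F(1, 5), F(3, 10), F(1, 2)]

beta = sum(min(a, b) for a, b in zip(p, q))                    # acceptance probability
p_adj = [max(F(0), a - b) / (1 - beta) for a, b in zip(p, q)]  # p'(x)

marginal = []
for x in range(len(p)):
    accepted = q[x] * min(F(1), p[x] / q[x])  # P(guess accepted, x = x')
    rejected = (1 - beta) * p_adj[x]          # P(guess rejected, x = x')
    assert accepted == min(q[x], p[x])
    marginal.append(accepted + rejected)

print(marginal)  # [Fraction(3, 5), Fraction(2, 5), Fraction(0, 1)]
```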

In particular, the expression for $P(\text{guess rejected}, x = x')$ initially puzzled me. By definition, $\beta$ is the expectation (over the distribution $q(x)$) of the acceptance probability. So how does it relate to $P(\text{guess rejected}, x = x')$?

$$\beta = E_{x \sim q(x)} \begin{cases} 1 & q(x) \leq p(x) \\ \frac{p(x)}{q(x)} & q(x) > p(x) \end{cases} = E_{x \sim q(x)} \min\left(1, \frac{p(x)}{q(x)}\right) = \sum_x \min(p(x), q(x))$$
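The case split and the closed form for $\beta$ agree, which is easy to confirm on a toy example. The distributions below are hypothetical; the ratio $p(x)/q(x)$ is only evaluated when $q(x) > p(x)$, so it never divides by zero.

```python
from fractions import Fraction as F

# Hypothetical target p and draft q over a 3-token vocabulary.
p = [F(3, 5), F(2, 5), F(0)]
q = [F(1, 5), F(3, 10), F(1, 2)]

# E_{x ~ q}[acceptance probability], written with the two cases.
beta_cases = sum(qx * (F(1) if qx <= px else px / qx) for px, qx in zip(p, q))
# Closed form: sum of the pointwise minimum.
beta_min = sum(min(px, qx) for px, qx in zip(p, q))

assert beta_cases == beta_min
print(beta_cases)  # 1/2
```

Here a drafted token is accepted half the time, so $1 - \beta = 1/2$ is the total probability of a rejection, summed over all drafted tokens.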

It turns out we can expand this probability by summing over the token $x_i$ that was drafted from $q(x)$:

$$\begin{align*}
&\quad P(\text{guess rejected}, x = x')\\
&= P(\text{guess rejected}, x = x', q(x = x_1)) + \dots + P(\text{guess rejected}, x = x', q(x = x_n))\\
&= \sum_i P(\text{guess rejected}, x = x', q(x = x_i))
\end{align*}$$

where

$$P(\text{guess rejected}, x = x', q(x = x_i))$$

denotes the joint probability that $x_i$ was drafted from $q(x)$, that this guess was rejected, and that the resampled token is $x'$. This expansion is necessary because a rejection of any drafted $x_i \sim q(x)$, not only $x_i = x'$, triggers a resampling from $p'(x)$, and that resampling can result in $x = x'$ being selected.

Thus, we can complete the equation:

$$\begin{align*}
&\quad \sum_i P(\text{guess rejected}, x = x', q(x = x_i))\\
&= \sum_i q(x_i) \left(1 - \min\left(1, \frac{p(x_i)}{q(x_i)}\right)\right) p'(x')\\
&= p'(x') \left(1 - \sum_i q(x_i) \min\left(1, \frac{p(x_i)}{q(x_i)}\right)\right)\\
&= (1 - \beta) p'(x')
\end{align*}$$

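where

$$1 - \min\left(1, \frac{p(x_i)}{q(x_i)}\right)$$

is the probability that the drafted token $x_i \sim q(x)$ is rejected. The step from the second to the third line uses $\sum_i q(x_i) = 1$, and the last step is the definition of $\beta$.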
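The expansion over drafted tokens is what resolved the puzzle for me, so here it is spelled out on a toy example. The sketch (hypothetical `p`, `q`, and helper `rejected_via`, exact arithmetic via `fractions`) lists how much each drafted $x_i$ contributes to $P(\text{guess rejected}, x = x')$ for $x' = 0$, then checks the final identity for every $x'$.

```python
from fractions import Fraction as F

# Hypothetical target p and draft q over a 3-token vocabulary (all q > 0).
p = [F(3, 5), F(2, 5), F(0)]
q = [F(1, 5), F(3, 10), F(1, 2)]
beta = sum(min(a, b) for a, b in zip(p, q))
p_adj = [max(F(0), a - b) / (1 - beta) for a, b in zip(p, q)]


def rejected_via(i, x_prime):
    """P(x_i drafted, guess rejected, resample lands on x') = q(x_i) * reject(x_i) * p'(x')."""
    return q[i] * (1 - min(F(1), p[i] / q[i])) * p_adj[x_prime]


# Contribution of each drafted token x_i to P(guess rejected, x = 0).
print([rejected_via(i, 0) for i in range(len(q))])
# prints [Fraction(0, 1), Fraction(0, 1), Fraction(2, 5)]

for x_prime in range(len(p)):
    total = sum(rejected_via(i, x_prime) for i in range(len(q)))
    assert total == (1 - beta) * p_adj[x_prime]
    assert total == p[x_prime] - min(q[x_prime], p[x_prime])
```

For $x' = 0$, all of the rejection mass comes from drafting token 2, never from drafting $x'$ itself, which is exactly why the sum over $i$ is needed.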

Simulation

You can experiment with a simulation and concrete probability calculations using the script here to verify that the algorithm is indeed correct.
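For a quick self-contained check, here is a minimal Monte Carlo sketch of a single speculative sampling step (one drafted token). It is my own illustration, not the script mentioned above; the function name `speculative_sample` and the distributions are hypothetical, and the empirical frequencies should approach `p`.

```python
import random
from collections import Counter


def speculative_sample(p, q, rng):
    """Draw one token by speculative sampling; its distribution should equal p."""
    tokens = range(len(p))
    x = rng.choices(tokens, weights=q, k=1)[0]          # draft a guess from q
    if rng.random() < min(1.0, p[x] / q[x]):            # accept w.p. min(1, p/q)
        return x
    residual = [max(0.0, a - b) for a, b in zip(p, q)]  # p' is proportional to max(0, p - q)
    return rng.choices(tokens, weights=residual, k=1)[0]


p = [0.6, 0.4, 0.0]  # hypothetical target distribution
q = [0.2, 0.3, 0.5]  # hypothetical draft distribution
rng = random.Random(42)
n = 100_000
counts = Counter(speculative_sample(p, q, rng) for _ in range(n))

for x in range(len(p)):
    print(f"token {x}: target {p[x]:.3f}, empirical {counts[x] / n:.3f}")
```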

More

If you’re interested in this topic, check out this well-curated repository surveying speculative decoding strategies. It covers approaches not only for transformers but for many other models as well. I’m not an expert (not even a beginner) in this area, so that’s all I can say for now.