1. Introduction to Speculative Decoding

Given a score model S (for example, Llama-3-70B) and a draft model D (for example, Llama-3-8B), the process of speculative decoding can be described as follows:

input_ids = Tensor(...)  # (seq_len,)

while True:
    # D proposes k draft tokens for positions [seq_len, ..., seq_len + k - 1].
    draft_outputs = D(input_ids)  # (k,)
    # Given the prefix plus the k draft tokens, S produces its own predictions
    # for positions [seq_len, ..., seq_len + k] in a single forward pass.
    score_outputs = S(cat(input_ids, draft_outputs))  # (k + 1,)
    accepted = 0
    for i in range(k):
        if not verify(draft_outputs[i], score_outputs[i]):
            break
        input_ids.append(draft_outputs[i])
        accepted += 1
    # If the i-th draft token was rejected, score_outputs[i] is the corrected
    # token; if all k drafts were accepted, score_outputs[k] is a free bonus token.
    input_ids.append(score_outputs[accepted])
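
The verify step above is left abstract. A minimal sketch of the greedy (temperature 0) case, matching the top-p=1 setting in the figures below; with sampling, vLLM instead uses rejection sampling so that the output distribution still matches the score model:

def verify(draft_token, score_token):
    # Greedy acceptance: keep the draft token only if it is exactly the token
    # the score model itself would have produced at that position.
    return draft_token == score_token
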
Speculative decoding workflow in vLLM (k=3, top-p=1). k=3 indicates that the draft model generates 3 tokens per forward pass, and top-p=1 means that only 1 candidate is proposed for each token. As shown in the picture, at the prefill stage the input sequence is first fed into both the draft and score models to populate their KV caches; the draft model's output at this stage is discarded. Then T5 is fed into the draft model to generate the proposed T6', T7', and T8'. To verify these tokens, T5, T6', T7', and T8' are fed into the score model, which produces T6, T7*, T8*, and T9* in one forward pass. Note that T6 must be correct because it is generated from T5 by the score model itself; however, T7*, T8*, and T9* are not guaranteed to be correct, since they are conditioned on the unverified draft tokens. The final step is to check T6', T7', and T8' against the score model's outputs to determine which of T7*, T8*, and T9* can be used. For example, if T6' and T7' are correct but T8' is not, the final accepted tokens are T6', T7', and T8* (the score model's correction), which means the score model generates 3 tokens in one forward pass.
Workflow of speculative decoding in vLLM (k=1, top-p=1). As in the previous picture, if T6' is correct, the final accepted tokens are T6' and T7*: one generated by the draft model and one by the score model. The score model thus generates 2 tokens in one forward pass.

2. How Speculative Decoding Works in vLLM

In vLLM, speculative decoding is integrated with the system’s continuous batching architecture, where different requests are processed together in a single batch, enabling higher throughput. vLLM uses two key components to implement this:

  • Draft Runner: This runner is responsible for executing the smaller proposer model to propose candidate tokens.
  • Target Runner: The target runner verifies the tokens by running the larger scorer model.

vLLM’s system is optimized to handle this process efficiently, allowing speculative decoding to work seamlessly with continuous batching, which increases the overall system performance.

Diagram illustrating how the draft and target runners interact within the vLLM batching system.

To implement speculative decoding in vLLM, two crucial components had to be modified:

  • Scheduler: The scheduler was adjusted to handle multiple token slots within a single forward pass, enabling the simultaneous generation and verification of several tokens.
  • Memory Manager: The memory manager now handles the KV cache for both the draft and scorer models, ensuring smooth processing during speculative decoding.
System architecture of speculative decoding in vLLM.

3. Types of Speculative Decoding Supported in vLLM

3.1. Draft Model-Based Speculative Decoding

This is the most commonly used form of speculative decoding, where a smaller model predicts the next tokens, and a larger model verifies them. A common example would be using a Llama 68M model to predict tokens for a Llama 2 70B model. This approach requires careful selection of the draft model to balance accuracy and overhead.
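
As an illustration, here is a minimal sketch of enabling draft-model-based speculative decoding in vLLM. The model IDs are examples, and the exact argument names (speculative_model, num_speculative_tokens) have varied across vLLM releases, with newer versions exposing them through a speculative config instead:

from vllm import LLM, SamplingParams

# A small Llama draft model proposes 5 tokens per step; the Llama 2 70B target
# verifies them. Model IDs and argument names are illustrative and may differ
# across vLLM versions.
llm = LLM(
    model="meta-llama/Llama-2-70b-chat-hf",
    tensor_parallel_size=4,
    speculative_model="JackFram/llama-68m",
    num_speculative_tokens=5,
)

outputs = llm.generate(
    ["The future of AI is"],
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)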

Choosing the correct draft model is essential for maximizing the efficiency of speculative decoding. The draft model needs to be small enough to avoid creating significant overhead but still accurate enough to provide a meaningful performance boost.

However, selecting the right draft model can be challenging. For example, in models like Llama 3, finding a suitable draft model is difficult due to differences in vocabulary size. Speculative decoding requires that the draft and target models share the same vocabulary, and in some cases this can limit the use of speculative decoding. Therefore, in the following sections, we introduce several draft-model-free speculative decoding methods.

3.2. Prompt Lookup Decoding

An example of prompt lookup decoding. Given the prompt, we build all 2-grams as lookup keys. The values are the three tokens following each key. During generation, we check whether the current 2-gram matches any key; if so, we propose the tokens stored in the corresponding value.

Otherwise known as n-gram matching, this approach is effective for use cases like summarization and question-answering, where there is a significant overlap between the prompt and the answer. Instead of using a small model to propose tokens, the system speculates based on the information already available in the prompt. This works particularly well when the large model repeats parts of the prompt in its answers.
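
A minimal, self-contained sketch of the idea in plain Python (not vLLM's actual n-gram worker; the table layout follows the figure above, with 2-gram keys and up to 3 proposed tokens per key):

def build_ngram_table(prompt_tokens, n=2, k=3):
    # Map every n-gram in the prompt to the (up to) k tokens that follow its
    # first occurrence.
    table = {}
    for i in range(len(prompt_tokens) - n):
        key = tuple(prompt_tokens[i:i + n])
        if key not in table:
            table[key] = prompt_tokens[i + n:i + n + k]
    return table

def propose(generated_tokens, table, n=2):
    # If the last n generated tokens match a prompt n-gram, propose the tokens
    # that followed it in the prompt; otherwise propose nothing.
    return table.get(tuple(generated_tokens[-n:]), [])

prompt = "the quick brown fox jumps over the lazy dog".split()
table = build_ngram_table(prompt)
print(propose(["over", "the"], table))   # -> ['lazy', 'dog']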

4. MEDUSA

4.1. Roadmap

  1. [vllm][ISSUE] | Can vLLM support medusa head? #1023
  2. [vllm][ISSUE] | [Discussion] Will vLLM consider using Speculative Sampling to accelerating LLM decoding? #1171
  3. [vllm][PR] | [Speculative Decoding] Medusa Implementation with Top-1 proposer #4978

4.2. MEDUSA Heads

MEDUSA heads are additional decoding heads appended to the last hidden states of the original model.

Three heads are used to propose tokens for the following three positions. Head 1 proposes ["is", "'", "the"] for the first position. Head 2 proposes ["difficult", "is", "'"] for the second position. Head 3 proposes ["not", "difficult", "a"] for the third position. NOTE: all heads take the output of the last transformer block as their input.

Specifically, given the original model’s last hidden states $h_t$ at position $t$, we add $K$ decoding heads to $h_t$. The $k$-th head is used to predict the token in the $(t + k + 1)$-th position of the next tokens (the original language model head is used to predict the $(t + 1)$-th position).

$$ p_{t}^{(k)} = \mathrm{softmax}\left(W_{2}^{(k)}\cdot\left(\mathrm{SiLU}(W_{1}^{(k)}\cdot h_{t})+h_{t}\right)\right), \quad \text{where } W_{2}^{(k)}\in\mathbb{R}^{V\times d},\ W_{1}^{(k)}\in\mathbb{R}^{d\times d}. $$
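
A minimal PyTorch sketch of a single MEDUSA head that instantiates the formula above (layer names and sizes are illustrative, not vLLM's actual implementation):

import torch
import torch.nn as nn

class MedusaHead(nn.Module):
    """One MEDUSA head: a residual SiLU block followed by a projection to the vocabulary."""
    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        self.w1 = nn.Linear(hidden_size, hidden_size, bias=False)   # W1^(k): d x d
        self.w2 = nn.Linear(hidden_size, vocab_size, bias=False)    # W2^(k): V x d
        self.act = nn.SiLU()

    def forward(self, h_t: torch.Tensor) -> torch.Tensor:
        # p_t^(k) = softmax(W2 (SiLU(W1 h_t) + h_t))
        return torch.softmax(self.w2(self.act(self.w1(h_t)) + h_t), dim=-1)

# K = 3 heads propose tokens for positions t+2, t+3, t+4 (the original LM head covers t+1).
hidden_size, vocab_size, K = 1024, 32000, 3
heads = nn.ModuleList(MedusaHead(hidden_size, vocab_size) for _ in range(K))
h_t = torch.randn(1, hidden_size)                              # last hidden state of the backbone at position t
top3 = [head(h_t).topk(3, dim=-1).indices for head in heads]   # top-3 proposals per head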

Unlike a draft model, MEDUSA heads are trained in conjunction with the original backbone model, which can remain frozen during training (MEDUSA-1) or be trained together (MEDUSA-2).

4.3. Tree Attention

The top-2 predictions from the first MEDUSA head and the top-3 from the second result in a total of $2 \times 3 = 6$ candidates. Each of these candidates corresponds to a distinct branch within the tree structure.

To guarantee that each token only accesses its predecessors, an attention mask is devised that exclusively permits attention flow from the current token back to its antecedent tokens.
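
A small sketch of how such a mask can be constructed for the 2 × 3 candidate tree above. The node layout here is hypothetical: node 0 is the current token, nodes 1-2 are the top-2 candidates from the first head, and nodes 3-8 are the top-3 candidates from the second head attached under each first-head candidate:

import torch

# parent[i] is the index of node i's parent in the candidate tree (-1 for the root).
parent = [-1, 0, 0, 1, 1, 1, 2, 2, 2]
n = len(parent)

# Each node may attend only to itself and its ancestors along its own branch,
# so attention never flows between sibling branches.
mask = torch.zeros(n, n, dtype=torch.bool)
for i in range(n):
    j = i
    while j != -1:
        mask[i, j] = True
        j = parent[j]

print(mask.int())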

5. EAGLE

5.1. Roadmap

  1. [vllm][PR] | [Speculative Decoding] EAGLE Implementation with Top-1 proposer #6830

5.2. Detailed Process

A comparison of the methods for drafting the fourth and fifth tokens, t4 and t5. t (represented by blue blocks) denotes tokens, and f (orange blocks) signifies the features, with subscripts indicating their positions in the sequence. The red border indicates the predictions of the draft model. For simplicity, the n in the n-gram for Lookahead, as shown in the figure, has been set to 2.

This Feishu drawboard shows the detailed process of speculative decoding with EAGLE in vLLM:

6. DeepseekMTP

Structure of DeepseekMTP. This figure also demonstrates the training process of draft models, which are fed with continuous tokens and corresponding masks to predict the next tokens for each position. This process is similar to the pre-training process of the larger scorer model.
Compute graph of DeepseekMTP.

7. Discussion

7.1. Performance Insights, Speedups, and Trade-offs

Ref: [vllm] | How Speculative Decoding Boosts vLLM Performance by up to 2.8x

Speculative decoding offers significant performance benefits in low-QPS (queries per second) environments. For example, in testing on the ShareGPT dataset, vLLM demonstrated up to a 1.5x speedup in token generation when using draft model-based speculative decoding. Similarly, prompt lookup decoding has shown speedups of up to 2.8x when applied to summarization datasets, such as CNN/DailyMail.

Performance comparison showing speculative decoding delivering up to 1.5x speedup at QPS=1 for Llama3-70B on ShareGPT with 4xH100 using a draft model (turboderp/Qwama-0.5B-Instruct), and up to 2.8x speedup at QPS=1 for Llama3-70B on CNN/DailyMail with 4xH100 using n-gram matching.

However, in high-QPS environments, speculative decoding may introduce performance trade-offs. The extra compute required to propose and verify tokens can sometimes slow down the system when it is already compute-bound, as seen when the number of requests per second increases. In such cases, the overhead of speculative decoding can outweigh its benefits, leading to reduced performance.

At high QPS, we observe a 1.4x slowdown for Llama3-70B on ShareGPT with 4xH100 and a 1.8x slowdown for Llama3-70B on CNN/DailyMail with 4xH100.

7.2. Why exactly is batch expansion inefficient?

Ref: Optimizing attention for spec decode can reduce latency / increase throughput

Looking at the Llama 2 architecture, we can work out each component's algorithmic complexity with respect to the number of speculative tokens and the sequence length. The baseline is non-speculative decoding, so factors such as d_model are ignored, as they are the same in either case.

Each of these components scales linearly with the number of speculative tokens, except for attention, which scales with num_spec_tokens * seq_len. This means that for large batch sizes, large speculative trees, and/or long sequence lengths, attention becomes the computational bottleneck.

To optimize the attention operation, the key observation is that parts of the attention computation are duplicated when scoring different speculative tokens that share the same prefix sequence.

Speaking theoretically, we can optimize attention for speculative scoring by reducing redundant QK^T computations + loads and Softmax(...)V loads:

  • Share K loads for common tokens
  • Share QK^T compute for common tokens
  • Share V loads for common tokens

We should experimentally verify this analysis: one weakness is that Softmax(...)V computation is still O(num_spec_tokens * seq_len).
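
As a small numerical sanity check of the first two bullets, the sketch below (hypothetical shapes, PyTorch) shows that scoring k speculative tokens against a shared prefix with one matrix multiply produces the same QK^T scores as the batch-expanded version that reloads and remultiplies the prefix keys k times:

import torch

seq_len, k, d = 128, 4, 64             # prefix length, speculative tokens, head dim
K_prefix = torch.randn(seq_len, d)     # keys of the shared prefix (one copy in the KV cache)
q_spec = torch.randn(k, d)             # queries of the k speculative tokens

# Batch expansion: each speculative token is scored as its own sequence, so the
# prefix keys are loaded and multiplied k separate times.
scores_expanded = torch.stack([q @ K_prefix.T for q in q_spec])

# Shared-prefix scoring: load K_prefix once and do a single (k, seq_len) matmul.
scores_shared = q_spec @ K_prefix.T

assert torch.allclose(scores_expanded, scores_shared, atol=1e-5)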

References

  1. [vllm] | Speculative Decoding
  2. [vllm] | How Speculative Decoding Boosts vLLM Performance by up to 2.8x
  3. [vllm] | How to Use Speculative Decoding in vLLM.
  4. [vllm][PR] | [Speculative Decoding] Medusa Implementation with Top-1 proposer #4978
  5. A Hitchhiker's Guide to Speculative Decoding
  6. [vllm] | What is lookahead scheduling in vLLM?
  7. Optimizing attention for spec decode can reduce latency / increase throughput
  8. [vllm][ISSUE] | [RFC]: Automate Speculative Decoding #4565
  9. [HF] | Faster Assisted Generation with Dynamic Speculation