Medusa: a simpler, more user-friendly framework for accelerating LLM generation. Instead of using an additional draft model like speculative decoding, Medusa merely introduces a few additional decoding heads, following the idea of [Stern et al. 2018] with some other ingredients. Despite its simple design, Medusa can improve the generation efficiency of LLMs by about 2x.
Why is LLM Generation Inefficient?
From a systems perspective, LLM generation follows a memory-bound computational pattern with the main latency bottleneck arising from memory reads/writes rather than arithmetic computations. This issue is rooted in the inherently sequential nature of the auto-regressive decoding process. Each forward pass necessitates the transfer of the entire model's parameters from High-Bandwidth Memory (HBM) to the accelerator's compute units. This operation, while only producing a single token for each sample, fails to fully utilize the arithmetic computation capabilities of modern accelerators, resulting in inefficiency.
Before the rise of LLMs, a common mitigation(缓解) for this inefficiency was to simply increase the batch size, enabling the parallel production of more tokens. But the situation becomes far more complicated with LLMs. Increasing the batch size in this context not only introduces higher latency but also substantially inflates(增加) the memory requirements for the Transformer model's key-value cache. This trade-off makes the use of large batches impractical for many applications where low latency is a critical requirement.
Is Speculative Decoding the Ultimate Solution?
Given the challenges outlined, one appealing strategy to accelerate text generation is more efficient computational utilization—specifically, by processing more tokens in parallel. This is precisely where speculative decoding comes into play. The methodology employs a streamlined "draft" model to generate a batch of token candidates at each step quickly. These candidates are then validated by the original, full-scale language model to identify the most reasonable text continuations. The underlying logic(潜在的) hinges (依赖于)on an intriguing assumption: the draft model, although smaller, should be proficient enough to churn out sequences that the original model will find acceptable.
If this assumption holds, the draft model can rapidly produce token sequences while the original model efficiently vets(仔细检查) multiple tokens in parallel—thereby maximizing computational throughput. Recent research indicates that with a well-tuned draft model, speculative decoding can cut latency by an impressive factor of up to 2.5x.
However, the approach is not without its challenges:
- Finding the Ideal Draft Model: Identifying a "small yet mighty" draft model that aligns well with the original model is easier said than done.
- System Complexity: Hosting two distinct models in one system introduces layers of complexity, both computational and operational, especially in distributed settings.
- Sampling Inefficiency: When doing sampling with speculative decoding, an importance sampling scheme needs to be used. This introduces additional overhead on generation, especially at higher sampling temperatures.
These complexities and trade-offs have limited the broader adoption of speculative decoding techniques. So, while promising, speculative decoding isn't widely adopted.
Remark: We use speculative decoding to refer to those methods that require an independent draft model here. In a broader sense, our method can also be viewed as speculative decoding, while the draft model is entangled(卷入) with the original model.
这里的并行思路很新奇,因为原来模型是接收seq后,一个个产生token,每个token的产生基本可以说是线性的顺序产生的,而现在speculative的思路就是:我们拿一个“draft” 模型能够快速的产生许多候选的token,然后我们用更复杂的模型去筛选找出合适的,这个筛选找的过程就不需要是线性的,可以并行执行。因为拿着已有的token,我们只需要计算比如draft给的是 “你”“好”“我”“是”,那么可以并行计算在 “你”这个序列下,下一个token是否是“好”,在“你好我”这个序列下,下一个token的结果是否是“是”,只要找到一个序列能满足(acceptable)那么我们就能输出结果了,这就是让他并行化的方法(我理解的)。而这个并行化的核心就是依赖于一个“有趣的推测”:虽然draft model很小,但是仍然足够去产生原始的大的模型找到的满足attention结果的序列。
但是它存在的三个挑战的确让它没办法很好的流行开
Medusa: Marrying Simplicity with Efficiency
When it comes to speculative decoding, its promise is counteracted by its inherent complexity. Recognizing the need for a more user-friendly yet potent solution, we proudly present Medusa. This innovative framework not only accelerates LLM generation but also makes LLM technology accessible and usable for a broader audience. Our approach revisits an underrated gem from the paper "Blockwise Parallel Decoding for Deep Autoregressive Models" [Stern et al. 2018] back to the invention of the Transformer model: rather than pulling in an entirely new draft model to predict subsequent tokens, why not simply extend the original model itself? This is where the "Medusa heads" come in. These additional decoding heads seamlessly integrate with the original model, producing blocks of tokens at each generative juncture. Unlike the draft model, Medusa heads can be trained in conjunction with the original model, which remains frozen during training. This method allows for fine-tuning large models on a single GPU, taking advantage of the powerful base model's learned representations. Additionally, since the new heads consist of just a single layer akin to the original language model head, Medusa does not add complexity to the serving system design and is friendly to distributed settings.
On its own, Medusa heads don't quite hit the mark of doubling processing speeds. But here's the twist: When we pair this with a tree-based attention mechanism, we can verify several candidates generated by Medusa heads in parallel. This way, the Medusa heads' predictive prowess(能力) truly shone through, offering a 2x to 3x boost in speed.
We didn't stop there, either. Eschewing(避免) the traditional importance sampling scheme, we created an efficient and high-quality alternative crafted specifically for the generation with Medusa heads. This new approach entirely sidesteps(避免) the sampling overhead, even adding an extra pep(活力) to Medusa's already accelerated step.
In a nutshell, we solve the challenges of speculative decoding with a simple system:
- No separate model: Instead of introducing a new draft model, we train multiple decoding heads on the same model.
- Simple integration to existing systems: The training is parameter-efficient so that even GPU poor can do it. And since there is no additional model, there is no need to adjust the distributed computing setup.
- Treat sampling as a relaxation: Relaxing the requirement of matching the distribution of the original model makes the non-greedy generation even faster than greedy decoding.
The figure below offers a visual breakdown of the Medusa pipeline for those curious about the nuts and bolts.
Overview of Medusa. Medusa introduces multiple heads on top of the last hidden states of the LLM, enabling the prediction of several subsequent tokens in parallel.
When augmenting a model with Medusa heads, the original model is frozen during training, and only the Medusa heads undergo fine-tuning. This approach makes it feasible to fine-tune large models on a single GPU.
During inference, each head generates multiple top predictions for its designated position. These predictions are assembled into candidates and processed in parallel using a tree-based attention mechanism. The final step involves utilizing a typical acceptance scheme to select reasonable continuations, and the longest accepted candidate prefix will be used for the next decoding phase.
The efficiency of the decoding process is enhanced by accepting more tokens simultaneously, thus reducing the number of required decoding steps.
Let's dive into the three components of Medusa: Medusa heads, tree attention, and typical acceptance scheme.
Medusa heads
So what exactly are Medusa heads? They're akin to the language model head in the original architecture (the last layer of a causal Transformer model), but with a twist: they predict multiple forthcoming tokens, not just the immediate next one. Drawing inspiration from the Blockwise Parallel Decoding approach, we implement each Medusa head as a single layer of feed-forward network, augmented with a residual connection.
Training these heads is remarkably straightforward. You can either use the same corpus that trained the original model or generate a new corpus using the model itself. Importantly, during this training phase, the original model remains static; only the Medusa heads are fine-tuned. This targeted training results in a highly parameter-efficient process that reaches convergence swiftly—especially when compared to the computational heaviness of training a separate draft model in speculative decoding methods.
The efficacy of Medusa heads is quite impressive. On the Vicuna models we test, Medusa heads achieve a top-1 accuracy rate of approximately 60% for predicting the 'next-next' token. Yet, there is still room for improvement.
Tree attention
During our tests, we uncovered some striking metrics: although the top-1 accuracy for predicting the 'next-next' token hovers around 60%, the top-5 accuracy soars to over 80%. This substantial increase indicates that if we can strategically leverage the multiple top-ranked predictions made by the Medusa heads, we can significantly amplify the number of tokens generated per decoding step.
With this goal, we first craft a set of candidates by taking the Cartesian product(笛卡尔积) of the top predictions from each Medusa head. We then encode the dependency graph into the attention following the idea from graph neural networks so that we can process multiple candidates in parallel.
For example, let's consider a scenario where we use top-2 predictions from the first Medusa head and top-3 predictions from the second, as illustrated below. In this case, any prediction from the first head could be paired with any prediction from the second head, culminating in a multi-level tree structure. Each level of this tree corresponds to predictions from one of the Medusa heads. Within this tree, we implement an attention mask that restricts attention only to a token's predecessors, preserving the concept of historical context. By doing so and by setting positional indices for positional encoding accordingly, we can process a wide array of candidates simultaneously without needing to inflate the batch size.
We would also remark that a few independent works also adopt very similar ideas of tree attention [1, 2]. Compared with them, our methodology leans towards a simpler form of tree attention where the tree pattern is regular and fixed during inference, which enables a preprocessing of tree attention mask that further improves the efficiency.
Tree Attention. This visualization demonstrates the use of tree attention to process multiple candidates concurrently. As exemplified, the top-2 predictions from the first Medusa head and the top-3 from the second result in 2*3=6 candidates. Each of these candidates corresponds to a distinct branch within the tree structure. To guarantee that each token only accesses its predecessors, we devise an attention mask that exclusively permits attention flow from the current token back to its antecedent tokens. The positional indices for positional encoding are adjusted in line with this structure.
Typical acceptance
In earlier research on speculative decoding, the technique of importance sampling was used to generate diverse outputs closely aligned with the original model's predictions. However, later studies showed that this method tends to become less efficient as you turn up the "creativity dial," known as the sampling temperature.
In simpler terms, if your draft model is just as good as your original model, you should ideally accept all its outputs, making the process super efficient. However, importance sampling will likely reject this solution in the middle.
In the real world, we often tweak the sampling temperature just to control the model's creativity, not necessarily to match the original model's distribution. So why not focus on just accepting plausible candidates?
We then introduce the typical acceptance scheme. Drawing inspiration from existing work on truncation sampling, we aim to pick candidates that are likely enough according to the original model. We set a threshold based on the original model's prediction probabilities, and if a candidate exceeds this, it's accepted.
In technical jargon, we take the minimum of a hard threshold and an entropy-dependent threshold to decide whether to accept a candidate as in truncation sampling. This ensures that meaningful tokens and reasonable continuations are chosen during decoding. We always accept the first token using greedy decoding, ensuring that at least one token is generated in each step. The final output is then the longest sequence that passes our acceptance test.
What's great about this approach is its adaptability. If you set the sampling temperature to zero, it simply reverts to the most efficient form—greedy decoding. When you increase the temperature, our method becomes even more efficient, allowing for longer accepted sequences, a claim we've confirmed through rigorous testing.
So, in essence, our typical acceptance scheme offers a more efficient way to generate the creative output of LLMs.