We discover that each attention head in the model consistently focuses on specific prompt attention features during generation. Moreover, this robust pattern can be obtained from an 'observation' window located at the end of the prompt. Drawing on this insight, SnapKV automatically compresses KV caches by selecting clustered important KV positions for each attention head. Our approach significantly reduces the growing computational overhead and memory footprint when processing long input sequences. Specifically, SnapKV achieves a consistent decoding speed with a 3.6x increase in generation speed and an 8.2x improvement in memory efficiency compared to the baseline when processing inputs of 16K tokens. At the same time, it maintains performance comparable to the baseline models across 16 long-sequence datasets. Moreover, SnapKV can process up to 380K context tokens on a single A100-80GB GPU using the HuggingFace implementation with minor changes, exhibiting only a negligible accuracy drop in the Needle-in-a-Haystack test. Further comprehensive studies suggest SnapKV's potential for practical applications.
There are many approaches to mitigating these problems, such as KV cache eviction during the generation stage [5–8]. However, most of these methods lack detailed evaluation in long-context settings. Moreover, they mainly focus on compressing the KV cache appended during decoding steps, while overlooking the realistic problem of compressing the KV cache for the prompt, which is typically the bottleneck in memory efficiency.
An additional challenge lies in compressing the KV cache for such vast prompts without losing crucial information needed for accurate generation, especially in scenarios with various noisy contexts.
In our paper, we find an important attention allocation phenomenon: only a portion of prompt tokens convey essential information for response generation, and these tokens remain unchanged during generation.
I don't quite follow this part: what does it mean that these prompt tokens "remain unchanged during generation"? What would it mean for this subset of prompt tokens to stay the same while generating? Perhaps it means that, throughout generation, it is always only this subset of tokens within the prompt that conveys the core information.
From our observations, we derive an innovative and intuitive method, SnapKV, which can smartly identify the attention allocation pattern and compress the KV cache for long sequence prompts without compromising the model’s accuracy.
We design experiments to explore the attention allocation pattern during generation, focusing on two key questions:
Is there a consistent attention allocation pattern for input sequence tokens?
Is it feasible to identify this pattern prior to the generation stage?
Our findings suggest that, for LLMs, the attention allocation over most input-sequence tokens stays consistent during generation. Thus, LLMs know what you are looking for before generation.
Inspired by our observations above, we develop an efficient and fine-tuning-free algorithm, SnapKV, which efficiently identifies critical attention features and compresses KV cache correspondingly with minimal model modification (See Fig. 1).
Heavy-Hitter Oracle (H2O) introduces a policy that greedily drops KVs during generation based on a scoring function derived from cumulative attention. While this approach effectively compresses the KVs appended to the cache during generation, it overlooks compression of prompt KVs, which is crucial for reducing memory and computational overhead.
Building on a similar concept, Adaptive KV Compression (FastGen) implements a dual-phase algorithm that encompasses four KV cache compression policies. Initially, it identifies optimal policies through profiling results obtained from prompt encoding. Subsequently, it dynamically evicts caches during the generation phase based on these policies. Nonetheless, it faces a similar problem to H2O.
ScissorHands focuses on identifying and retaining pivotal tokens that exhibit a consistent attention weight pattern with previous token windows during generation steps. However, this method concentrates solely on the window of previous pivotal tokens during generation and neglects the extensive prompt that contains essential information for generating accurate responses. This oversight could lead to an inability to extract detailed information from prompts.
- Pattern can be identified before generation.
In this experiment, we split the attention features of the input sequence at each layer into multiple windows of 128 tokens each, and calculate the averaged attention weights of the last 20 windows separately. To understand the attention allocation patterns along input sequences, we calculate the overlap rates between the important attention features of the input sequence (those with high average attention weights) identified by each window and those actually used during generation. The experimental results are shown in Fig. 2.
We observe that the last window of the input sequence recognizes an attention allocation pattern highly similar to that of the actual generation.
Figure 2: The overlap rates between attention features of the input sequence, selected by various windows along the input and during generation, with each line representing a model layer.
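As a concrete illustration of this overlap computation, here is a minimal sketch in PyTorch; the window length, number of selected features, and tensor shapes are assumptions for illustration rather than the paper's exact implementation.

```python
import torch

def topk_positions(attn_weights: torch.Tensor, k: int) -> torch.Tensor:
    # attn_weights: [num_heads, window_len, prefix_len]
    # Average attention each prefix position receives from this window's queries.
    scores = attn_weights.mean(dim=1)             # [num_heads, prefix_len]
    return scores.topk(k, dim=-1).indices         # [num_heads, k]

def overlap_rate(idx_a: torch.Tensor, idx_b: torch.Tensor, prefix_len: int) -> float:
    # Fraction of positions selected from window A that are also selected from window B.
    mask_a = torch.zeros(idx_a.shape[0], prefix_len, dtype=torch.bool).scatter_(1, idx_a, True)
    mask_b = torch.zeros(idx_b.shape[0], prefix_len, dtype=torch.bool).scatter_(1, idx_b, True)
    return (mask_a & mask_b).sum().item() / mask_a.sum().item()

# Toy comparison of the last prompt window against a generation window (random data).
num_heads, window_len, prefix_len, k = 32, 128, 4096, 256
attn_last_window = torch.rand(num_heads, window_len, prefix_len).softmax(dim=-1)
attn_generation = torch.rand(num_heads, window_len, prefix_len).softmax(dim=-1)
rate = overlap_rate(topk_positions(attn_last_window, k),
                    topk_positions(attn_generation, k), prefix_len)
print(f"overlap rate: {rate:.2f}")
```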
- Pattern is consistent during generation.
We study whether the positions of features identified as crucial in the last window of the input sequence maintain their significance during subsequent token generation. In the experiment, we split the generated tokens into 4 windows for every layer, each spanning 128 tokens, to compute the averaged overlap rates of these windows versus the last window of the input sequence. As shown in Fig. 3, active attention features of the input sequence obtained from the last window exhibit remarkable consistency throughout the generation process, as evidenced by high overlap rates.
Figure 3: The layer-wise overlap rates between input sequence attention features selected by the last window of the input sequence and those selected by 4 windows along generation.
In the attention mechanism, growth in the prompt significantly increases the time complexity of generation due to the Query-Key matrix multiplication. SnapKV addresses this issue by maintaining a constant amount of prompt KVs during generation, significantly reducing serving times for long-context LLMs. To structure our method coherently, we propose the following terminology:
Prompt Length ($L_{\text{prompt}}$): The total length of the user-provided input.
Observation Window ($L_{\text{obs}}$): The last segment of the prompt. This window is crucial for analyzing the influence of different contexts on attention allocation patterns.
Prefix Length ($L_{\text{prefix}}$): The length of the input preceding the observation window. It is part of the prompt and does not include the observation window. Overall, we have: $L_{\text{prompt}} = L_{\text{prefix}} + L_{\text{obs}}$.
Voting: The process of calculating attention weights for each query within the observation window across all heads, and aggregating these weights to highlight the prefix positions that are considered most significant. For a single batch of sequence, formally:

$$C = \sum_{i=0}^{L_{\text{obs}}} W_{\text{obs}}[:, i, :], \qquad I = \mathrm{Top}_k(C, k),$$

where $\mathrm{Top}_k(C, k)$ selects the indices $I$ of the top $k$ values in tensor $C$ per head, $k$ is defined as $\lfloor p \times L_{\text{prefix}} \rfloor$, where $p$ stands for the compression rate, and the tensor $W_{\text{obs}} \in \mathbb{R}^{N \times L_{\text{obs}} \times L_{\text{prefix}}}$ represents the subset of the prompt's softmax-normalized attention features over $N$ heads.
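A short sketch of this voting step in pseudo-PyTorch style, following the definitions above; the tensor name `w_obs` and its shape layout are assumptions for illustration.

```python
import torch

def vote(w_obs: torch.Tensor, p: float) -> torch.Tensor:
    # w_obs: [num_heads, L_obs, L_prefix] softmax-normalized attention weights of
    # the observation-window queries over the prefix keys.
    num_heads, l_obs, l_prefix = w_obs.shape
    k = max(1, int(p * l_prefix))        # number of prefix positions kept per head
    c = w_obs.sum(dim=1)                 # aggregate votes per prefix position: [num_heads, L_prefix]
    return c.topk(k, dim=-1).indices     # I: indices of the selected prefix positions, [num_heads, k]
```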
Hit Rate: We define attention features above a predefined threshold during generation as important features. The hit rate, $H$, is the number of important features successfully selected by the previous voting process over the total number of important features. $H$ quantifies the effectiveness of the voting mechanism and is calculated as follows:

$$M_{\text{vote}} = \mathrm{scatter}(\mathbf{0}, I, 1), \qquad M_{\text{threshold}} = \mathbb{1}(W_{\text{cur}} > \theta),$$
$$O = M_{\text{threshold}} \odot M_{\text{vote}}, \qquad H = \frac{\sum O}{\sum M_{\text{threshold}}},$$

where $W_{\text{cur}}$ represents the attention features between the current generated query and prefix keys, and $M_{\text{vote}}$ selects attention features by the voted indices $I$. The threshold operation filters $W_{\text{cur}}$ to retain only features with values over $\theta$, indicating important attention activations. The overlap $O$ measures the agreement between the attention features selected by $M_{\text{vote}}$ and $M_{\text{threshold}}$, quantifying the alignment of the current attention with previously identified important features. The hit rate $H$ is then computed as the ratio of the sum of the overlap $O$ to the sum of important features $M_{\text{threshold}}$, providing a metric for the efficacy of the attention mechanism in recognizing and emphasizing important attention features within the context. We use $\mathcal{H}$ to denote the combination of the last two equations.
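And a corresponding sketch of the hit-rate computation described above, again with assumed names and shapes rather than the paper's exact code.

```python
import torch

def hit_rate(w_cur: torch.Tensor, voted_idx: torch.Tensor, theta: float) -> float:
    # w_cur: [num_heads, L_prefix] attention of the current generated query over prefix keys.
    # voted_idx: [num_heads, k] indices selected by the voting step above.
    m_vote = torch.zeros_like(w_cur, dtype=torch.bool).scatter_(1, voted_idx, True)
    m_threshold = w_cur > theta                    # features deemed important during generation
    overlap = m_threshold & m_vote                 # important features that voting also captured
    return overlap.sum().item() / max(m_threshold.sum().item(), 1)
```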
Overall, SnapKV operates through two stages as follows:
Vote for important previous features. Through the voting process defined above (Eq. 2), we select the important attention features based on the observation window. Sec. 3 highlights the consistency of the attention allocation pattern within observation windows throughout generation, suggesting that these selected attention features are also vital for subsequent generation. Furthermore, we implement clustering to retain the features surrounding the selected attention features (Sec. 4.3). Lines 8-17 show the pseudocode of the voting process.
Update and store compressed keys and values. We concatenate the selected attention features with all features within the observation window, which encompasses all features containing the necessary prompt information. Lines 18-24 show the compression process. The concatenated KVs are stored for later use in generation, thereby saving memory usage.
Listing 1: Implementation of SnapKV in pseudo PyTorch style.
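The paper's listing is not reproduced in these notes. Below is a condensed sketch of the overall algorithm in the same pseudo-PyTorch spirit, combining the voting, pooling-based clustering, and compression steps; the line numbers cited in the text refer to the paper's original listing, not to this sketch, and the shapes and default hyperparameters here are placeholder assumptions.

```python
import torch
import torch.nn.functional as F

def snapkv_compress(query, key, value, window_size=32, max_capacity=1024, kernel_size=7):
    """Condensed SnapKV-style prompt KV compression (a sketch, not the paper's listing).
    query/key/value: [bsz, num_heads, seq_len, head_dim]"""
    bsz, num_heads, seq_len, head_dim = query.shape
    if seq_len <= max_capacity:
        return key, value                                    # nothing to compress

    # Attention of the observation window (last `window_size` queries) over all keys.
    obs_q = query[..., -window_size:, :]
    attn = torch.matmul(obs_q, key.transpose(-2, -1)) / head_dim ** 0.5
    attn = attn.softmax(dim=-1)                              # [bsz, heads, window, seq_len]

    # Vote: aggregate observation-window attention over prefix positions.
    scores = attn[..., :-window_size].sum(dim=-2)            # [bsz, heads, prefix_len]
    # Cluster: pooling keeps the neighbors of high-scoring positions (Sec. 4.3).
    scores = F.max_pool1d(scores.reshape(bsz * num_heads, 1, -1),
                          kernel_size, stride=1, padding=kernel_size // 2)
    scores = scores.reshape(bsz, num_heads, -1)

    # Select top prefix positions per head and gather their keys/values.
    k = max_capacity - window_size
    idx = scores.topk(k, dim=-1).indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)
    k_sel = key[..., :-window_size, :].gather(dim=2, index=idx)
    v_sel = value[..., :-window_size, :].gather(dim=2, index=idx)

    # Keep the observation window itself, then concatenate and store.
    k_new = torch.cat([k_sel, key[..., -window_size:, :]], dim=2)
    v_new = torch.cat([v_sel, value[..., -window_size:, :]], dim=2)
    return k_new, v_new
```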
There is a pool (pooling) operation on line 13 that I don't quite understand yet.
The later part of the paper mentions the advantages of pooling:
4.3 Efficient Clustering via Pooling
In LLMs, information retrieval and generation rely on features with high attention weights and are supplemented by copying the rest of the features in context using induction heads [15]. Hence, naively selecting the top features retains only portions of the details and loses the completeness of the information. For example, such compression might cause the LLMs to retrieve only the country code of a phone number and hallucinate the rest. Our experiments also revealed that selecting only the features with the highest weights is insufficient (Sec. 5.2). Such sparse selection risks compromising the contextual integrity encapsulated between features, thereby reducing accuracy. Based on these insights, we propose a fine-grained clustering algorithm utilizing a pooling layer, shown in Line 13.
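A toy example of why pooling before top-k selection helps keep such clusters intact; the scores, positions, and kernel size below are made up purely for illustration.

```python
import torch
import torch.nn.functional as F

# Toy vote scores over 10 prefix positions. Position 3 (say, the start of a short ID
# spanning positions 3-5) receives a high score; its continuation does not.
scores = torch.tensor([[0.1, 0.8, 0.1, 5.0, 0.1, 0.1, 0.1, 0.1, 0.1, 0.8]])

# Plain top-k scatters the budget across isolated spikes (positions 3, 1, 9, ...).
plain_idx = scores.topk(6).indices.sort().values

# Max pooling (kernel 5) spreads the peak to its neighbors before selection, so the
# contiguous block around position 3 survives and the whole ID is retained.
pooled = F.max_pool1d(scores, kernel_size=5, stride=1, padding=2)
pooled_idx = pooled.topk(6).indices.sort().values

print(plain_idx, pooled_idx)
```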
Overall, we want to answer the following two questions:
Does the nature of instructions in the prompt affect the hit rate?
Does the context and instruction positioning affect the hit rate?
4.2.1 Contextual Dependency of Patterns
We analyze whether instructions affect the selection of important features even when the provided context is the same. Our experiment applies different instructions to the same document and selects the important features based on the observation window that consists of both the instructions and their corresponding responses. Then we calculate the hit rates between important features selected by different instruction-response pairs within the same document using $\mathcal{H}$. By varying the instructions, we observe that different instructions prioritize different prefix attention features, as indicated by the descending trend in hit rates shown in Fig. 4. Our findings reveal an interesting aspect of the KV cache in LLMs: the important attention features change with different instructions. This variability challenges the effectiveness of static compression methods that depend on constant weighted importance or fixed policies. Thus, the complex relationship between context and the related KV cache emphasizes the need for context-aware compression strategies and highlights the capability of SnapKV to recognize this dynamic.
Our analysis also extends to the significance of instruction positioning on the interpretability of LLMs and their selection of important features. We calculate the average hit rate for the responses using the same observation window size as in the previous experiment. Our results shown in Fig. 5 indicate that across all three datasets, the hit rates are consistently high regardless of whether instructions are positioned before or after extensive supplementary contexts. This consistency suggests that SnapKV is able to identify attention allocation patterns regardless of the question’s positions.
This answers the second question raised at the start of this part, "Does the context and instruction positioning affect the hit rate?" → No.
The Needle-in-a-Haystack test challenges the model to accurately retrieve information from a specific sentence (the "needle") concealed within an extensive document (the "haystack"), with the sentence placed at a random location. Typically, sentences inserted in the middle of the prompt are harder to retrieve. To rigorously evaluate SnapKV's capabilities, we extended the document length to 380K tokens, the longest context that can be processed on a single A100-80GB GPU. We configured the prompt KV cache size to 1024, enabling SnapKV to select the 1024 most crucial attention features from the prompt for answer generation, with a maximum pooling kernel size of 5 and an observation window size of 16, both of which are hyperparameters that can be customized. The compelling outcomes in Fig. 6 from the Needle-in-a-Haystack test underscore SnapKV's potential to precisely manage small details in extremely long input contexts with a 380x compression ratio.
Figure 6: Needle-in-a-Haystack test performance comparison on single A100-80GB GPU, native HuggingFace implementation with only a few lines of code changed. The x-axis denotes the length of the document (the “haystack”) from 1K to 380K tokens; the y-axis indicates the position that the “needle” (a short sentence) is located within the document. For example, 50% indicates that the needle is placed in the middle of the document. Here LWMChat with SnapKV is able to retrieve the needle correctly before 140k and with only a little accuracy drop after. Meanwhile, the original implementation encounters OOM error with 33k input tokens (white dashed line).
Our evaluation utilizes the modified LongEval-Lines benchmark, incorporating randomly generated pairs and averaged scores. LongEval-Lines presents a greater challenge compared to Needle-in-a-Haystack because it involves identifying key-value pairs in noisy contexts of the same format, while in Needle-in-a-Haystack, the relevant information is more distinctly separated from other contexts. We apply max pooling with a kernel size of 5 and use the observation window with a size of 16, which are hyperparameters and could be customized according to different models. As illustrated in our results (Fig. 8), we find that pooling significantly enhances retrieval accuracy compared to methods not utilizing pooling. We hypothesize that this is because the initial portions of critical token clusters are weighted higher by attention mechanisms. Typically, large language models tend to copy the tokens surrounding the initial portions to keep the contextual integrity. However, naively compressed KV cache breaks this mechanism and could lead to partially correct results (Fig. 8). Note that throughout our experiments, the choice between max pooling and average pooling did not yield significant differences in performance.
Figure 8: Ablation study of pooling on LongEval-Lines. The evaluation includes inputs, each comprised of lines formatted as "line makeshift-penguin: REGISTER_CONTENT is <10536>", where the key is an adjective-noun pair and the value is a random 5-digit number. The model needs to retrieve the value based on a given key. The x-axis denotes the length of the input; the y-axis indicates the position of the ground truth, from 5K to 30K tokens. With pooling, the model can retrieve correct values before 16K and performs significantly better than the one without pooling.
A small bug catch: the title of the second panel in this figure is a typo; it should read Mistral-7B-Instruct-v0.2 with Pooling.
We evaluate SnapKV on these four models using LongBench, a multi-task benchmark designed to rigorously evaluate long-context understanding capabilities across various datasets, spanning single- and multi-document QA, summarization, few-shot learning, synthetic tasks, and code completion. We choose LWM-Text-Chat-1M with 1 million context length, and LongChat-7b-v1.5-32k, Mistral-7B-Instruct-v0.2, and Mixtral-8x7B-Instruct-v0.1 with 32k context length as our baselines. For each model, we test SnapKV with various settings: compressing the KV caches in the prompt to 1024, 2048, and 4096 tokens. We use max pooling with kernel size 7 and an observation window size of 32. Table 1 shows a negligible performance drop for models with SnapKV compared with the original implementations across 16 different datasets, even with the prompt KV cache compressed to 1024 tokens. Some models even outperform the baseline.
Table 1: Performance comparison of SnapKV and H2O across various LLMs on LongBench.
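For concreteness, the settings above could map onto the earlier compression sketch roughly as follows; this is a hypothetical invocation with placeholder shapes, not the authors' evaluation code.

```python
import torch

# Placeholder tensors standing in for one layer's prompt queries/keys/values.
bsz, num_heads, seq_len, head_dim = 1, 32, 16_000, 128
query = torch.randn(bsz, num_heads, seq_len, head_dim)
key = torch.randn(bsz, num_heads, seq_len, head_dim)
value = torch.randn(bsz, num_heads, seq_len, head_dim)

# Observation window 32, max-pooling kernel 7, prompt KV compressed to 1024/2048/4096.
for capacity in (1024, 2048, 4096):
    k_cmp, v_cmp = snapkv_compress(query, key, value,
                                   window_size=32, max_capacity=capacity, kernel_size=7)
    print(capacity, k_cmp.shape[2])    # compressed KV length equals the chosen capacity
```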
The goal of zero-shot and few-shot learning is to get a machine-learning model to perform a new task it was not trained for.
I feel this is the test I care most about: give the model some hints, let it learn a pattern or a line of reasoning, and then "create" some new content.
Is few-shot prompting the same as few-shot learning?
“Few-shot learning” and “zero-shot learning” are well-known concepts in machine learning that were studied long before LLMs appeared on the scene. In the context of LLMs, these terms are sometimes used interchangeably with “few-shot prompting” and “zero-shot prompting.” However, they are not the same.
Few-shot prompting refers to constructing a prompt consisting of a couple of examples of input-output pairs with the goal of providing an LLM with a pattern to pick up.
Few-shot learning is a model adaptation resulting from few-shot prompting, in which the model changes from being unable to solve the task to being able to solve it thanks to the provided examples.
In the context of LLMs, the “learning” is temporary and only applies to a particular chat conversation. The model’s parameters are not updated, so it doesn’t retain the knowledge or capabilities.
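A toy illustration of few-shot prompting in this sense: the examples in the prompt supply the pattern, and the resulting "learning" lasts only for this single request. The task and texts below are made up.

```python
# A few-shot prompt: two worked examples followed by the new input.
few_shot_prompt = """Classify the sentiment of each review as Positive or Negative.

Review: "The battery lasts all day and the screen is gorgeous."
Sentiment: Positive

Review: "It stopped working after a week and support never replied."
Sentiment: Negative

Review: "Setup took five minutes and it has run flawlessly since."
Sentiment:"""
# The model is expected to continue with "Positive"; no parameters are updated,
# so nothing persists beyond this conversation.
print(few_shot_prompt)
```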
I'm not that interested in the remaining tests, so I'll skip them for now...
5.5 Case Study: Compatibility with Parallel Decoding
In this section, we provide a novel perspective on employing KV cache compression synergistically with parallel decoding. Parallel decoding leverages a lightweight model or an adaptor to draft initial tokens, which are subsequently verified by larger LLMs.
This strategy effectively reduces memory overhead, a critical concern given the autoregressive nature of LLMs, which renders them more memory-intensive than computationally demanding. Specifically, in LLMs each decoding step generates a single token, and the transfer of weights between High Bandwidth Memory (HBM) and cache contributes significant overhead.