
Apple makes large models lazy: emit the first token faster while maintaining accuracy

2024-08-02




Machine Heart Report

Synced Editorial Department

Being lazy can help you work better.

Llama 3.1 has just been released. Have you tried it yet? Even on a top-of-the-line personal computer, running the smallest 8B version can still be noticeably slow. Researchers have proposed many methods to improve inference efficiency, but many of them sacrifice some of the model's accuracy.

Recently, a research team from Apple and Meta AI proposed a new method that speeds up the pre-filling phase of Llama 2 inference by more than 2x without a significant drop in accuracy, which may also offer ideas for accelerating Llama 3.1. They call the method LazyLLM, short for lazy large language model.



Paper title: LazyLLM: Dynamic Token Pruning for Efficient Long Context LLM Inference

Paper address: https://arxiv.org/abs/2407.14057

So how do they make an LLM lazy? To understand their approach, we first need to look at how standard prompt-based LLM inference works. In simple terms, it has two stages, pre-filling and decoding, as shown in Figure 1.



In the pre-filling phase, the model computes and stores the KV cache for every token in the prompt and predicts the first output token. The time spent in this phase is called the “Time to First Token (TTFT)”.

The pre-filling phase is followed by the decoding phase, in which the model reuses the cached keys and values to generate tokens one at a time until a stopping criterion is met.
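To make the two stages concrete, here is a minimal pure-Python sketch of a toy single-layer, single-head model. The functions `kv_for`, `attend`, `next_token` and `generate` are hypothetical stand-ins, not the paper's code or any real LLM API; the point is only the bookkeeping: pre-filling computes a key/value pair for every prompt token, while each decoding step computes one new pair and reuses the cache.

```python
import math


def kv_for(token: int) -> tuple[float, float]:
    # Hypothetical stand-in for one token's key/value projections (scalars for brevity).
    return math.sin(token), math.cos(token)


def attend(query: float, cache: list[tuple[float, float]]) -> float:
    # Softmax attention of one query over every cached (key, value) pair.
    scores = [query * k for k, _ in cache]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    return sum(w * v for w, (_, v) in zip(weights, cache)) / sum(weights)


def next_token(context: float, vocab_size: int = 100) -> int:
    # Toy "LM head" that maps an attention output to a token id.
    return int(abs(context) * 1000) % vocab_size


def generate(prompt: list[int], max_new_tokens: int) -> tuple[list[int], int, int]:
    kv_computed = 0
    cache: list[tuple[float, float]] = []
    # Pre-filling: compute and cache K/V for every prompt token, then predict the first token.
    for t in prompt:
        cache.append(kv_for(t))
        kv_computed += 1
    output = [next_token(attend(cache[-1][0], cache))]  # the last token's key doubles as its query
    prefill_cost = kv_computed
    # Decoding: each step computes K/V for one new token and reuses everything already cached.
    while len(output) < max_new_tokens:
        cache.append(kv_for(output[-1]))
        kv_computed += 1
        output.append(next_token(attend(cache[-1][0], cache)))
    return output, prefill_cost, kv_computed


tokens, prefill_cost, total_cost = generate(list(range(1000)), max_new_tokens=5)
print(tokens, prefill_cost, total_cost)
```

With a 1,000-token prompt and 5 output tokens, pre-filling performs 1,000 key/value computations and the four decoding steps add only one each, which is why the first token is the expensive one when prompts are long.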

During pre-filling, every Transformer layer processes every token in the prompt. With long prompts, TTFT can be slow: state-of-the-art Transformer-based LLMs are both deep and wide, and the cost of computing attention grows quadratically with the number of prompt tokens. For example, Llama 2 7B stacks 32 Transformer layers with a model dimension of 4096. In this setting, TTFT takes 21 times the wall-clock time of each subsequent decoding step and accounts for about 23% of the total generation time on the LongBench benchmark.
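A back-of-the-envelope FLOP count shows why the quadratic term matters. The sketch below is a rough estimate under stated assumptions, not a measurement: it counts 2 FLOPs per multiply-add, uses Llama 2 7B's published shape (32 layers, model dimension 4096, SwiGLU feed-forward width 11008), ignores causal masking, normalization, embeddings and the output head, and the function name `prefill_flops` is hypothetical.

```python
def prefill_flops(n: int, d: int = 4096, d_ff: int = 11008, layers: int = 32) -> tuple[int, int]:
    """Rough pre-filling FLOPs for an n-token prompt: (projections + MLP, attention)."""
    # Q, K, V and output projections: four d x d matrices applied to n tokens.
    projections = 2 * 4 * n * d * d
    # SwiGLU MLP: three d x d_ff matrices applied to n tokens.
    mlp = 2 * 3 * n * d * d_ff
    # Attention: Q @ K^T and softmax(...) @ V, each about n * n * d multiply-adds.
    attention = 2 * 2 * n * n * d
    return layers * (projections + mlp), layers * attention


for n in (1_000, 4_000, 16_000, 32_000):
    linear, attn = prefill_flops(n)
    print(f"n={n:>6}: attention share of pre-filling FLOPs = {attn / (linear + attn):.1%}")
```

The linear terms grow in proportion to n while the attention term grows with n², so under these assumptions the two break even at n = (8d + 6·d_ff)/4 = 24,704 tokens, and attention dominates beyond that.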

Therefore, optimizing TTFT is a critical step to make LLM inference efficient.

Although LLM inference optimization is an active research area, most methods focus on speeding up the decoding stage, and TTFT has received comparatively little attention. Some compression-based methods improve TTFT implicitly by reducing the size of the LLM.

Another line of research aims to improve TTFT for a given, static Transformer architecture. Within this line of work, a natural question arises: are all prompt tokens essential for generating the first token?

Figure 2 shows the team's analysis of an LLM on the LongBench benchmark.



As the figure shows, when the first token is generated, the attention scores over the input tokens are very sparse. This suggests that many tokens in the prompt are redundant and could be removed without affecting the prediction of the next token. This observation is the basis of LazyLLM.
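One simple way to quantify such sparsity, sketched below as an illustration rather than the paper's exact analysis, is to count how few prompt tokens account for most of the attention mass that the predicting position assigns. The helper `tokens_for_mass`, the example attention row and the 90% cutoff are hypothetical choices.

```python
def tokens_for_mass(weights: list[float], mass: float = 0.9) -> int:
    """Smallest number of tokens whose attention weights sum to at least `mass` of the total."""
    total = sum(weights)
    covered = 0.0
    # Greedily take the most-attended tokens first.
    for count, w in enumerate(sorted(weights, reverse=True), start=1):
        covered += w / total
        if covered >= mass:
            return count
    return len(weights)


# A hypothetical attention row over 10 prompt tokens: most of the mass sits on two of them.
row = [0.01, 0.62, 0.01, 0.02, 0.30, 0.01, 0.01, 0.01, 0.005, 0.005]
print(tokens_for_mass(row), "of", len(row), "tokens carry 90% of the attention")
```

The sparser the row, the smaller this count relative to the prompt length, and the more tokens are candidates for pruning.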

LazyLLM is broadly applicable, requires no training, and is effective. Figure 3 compares a standard LLM with LazyLLM.



LazyLLM

Figure 4 shows the overall framework of LazyLLM.



Starting from the full context, LazyLLM progressively prunes tokens, so the amount of computation shrinks toward the last layers of the model. Note that LazyLLM allows the model to select a different subset of tokens at each generation step, even tokens that were pruned in earlier steps. Compared with static pruning, which prunes tokens once and keeps that choice for all steps, dynamic pruning optimizes the prediction of the next token at every generation step, which helps preserve model performance.

Progressive token pruning

Some previous studies have successfully used token pruning to speed up LLM inference. However, these methods need to accumulate the full attention maps of the first few generation steps to analyze the importance of prompt tokens before pruning begins. They are therefore unsuitable for reducing TTFT, because they still have to compute the full KV cache in the pre-filling phase.

In contrast, LazyLLM is “very lazy”: starting from the very first iteration of inference (the pre-filling step), it computes only the tokens that are important for predicting the next token.

A key challenge in that first iteration is determining the importance of each token. Inspired by previous research showing that token hidden states evolve as they pass through Transformer layers, the team prunes tokens layer by layer at each generation step. Specifically, they use each layer's attention map to measure how important each input token is to the token being predicted.
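The sketch below illustrates one way to turn a layer's attention map into per-token importance scores: average, over attention heads, the softmax attention weight that the last position (the one predicting the next token) places on each prompt token. This matches our reading of the paper's confidence score, but the helper names, the toy vectors and the plain scaled dot-product attention are assumptions for illustration.

```python
import math


def softmax(xs: list[float]) -> list[float]:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def last_token_attention(queries: list[list[float]], keys: list[list[list[float]]]) -> list[list[float]]:
    """Per head: scaled dot-product attention of the last token's query over all keys.

    queries[h] is the last token's query in head h; keys[h][i] is token i's key in head h.
    """
    maps = []
    for q, head_keys in zip(queries, keys):
        scale = 1.0 / math.sqrt(len(q))
        maps.append(softmax([scale * sum(a * b for a, b in zip(q, k)) for k in head_keys]))
    return maps


def confidence_scores(attn_maps: list[list[float]]) -> list[float]:
    """Average each token's attention weight across heads (one score per prompt token)."""
    num_heads = len(attn_maps)
    return [sum(column) / num_heads for column in zip(*attn_maps)]


# Two heads, 2-dimensional vectors, three prompt tokens; token 0 aligns with both queries.
queries = [[1.0, 0.0], [0.0, 1.0]]
keys = [
    [[5.0, 0.0], [0.0, 0.0], [0.0, 5.0]],
    [[0.0, 5.0], [0.0, 0.0], [5.0, 0.0]],
]
scores = confidence_scores(last_token_attention(queries, keys))
print([round(s, 3) for s in scores])
```

Because every head's weights sum to 1, the averaged scores also sum to 1, which makes them comparable across layers before a threshold is applied.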

Once the confidence scores are computed, the next difficulty is choosing the threshold for pruning tokens.

The right threshold can vary across layers and tasks, because the attention scores themselves vary. The team's solution is a top-k percentile selection strategy: if a token's confidence score falls below the k-th percentile of the scores of the input tokens, the token is pruned. Once pruned, a token no longer participates in the computation of any subsequent layer.

That is, the tokens used by each layer are a subset of the tokens used by the layers before it.
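A minimal sketch of percentile-based pruning applied at two successive pruning layers, assuming per-layer confidence scores are already available. The percentile rule (index `int(n * k / 100)` into the sorted scores) and always keeping the final position, whose representation is needed to predict the next token, are assumptions of this sketch rather than details quoted from the paper; `percentile_prune` is a hypothetical name.

```python
def percentile_prune(token_ids: list[int], scores: list[float], k: float) -> list[int]:
    """Drop tokens whose score is below the k-th percentile; always keep the last position."""
    if k <= 0 or len(token_ids) <= 1:
        return list(token_ids)
    ranked = sorted(scores)
    threshold = ranked[min(int(len(ranked) * k / 100), len(ranked) - 1)]
    last = len(token_ids) - 1
    return [t for i, (t, s) in enumerate(zip(token_ids, scores)) if s >= threshold or i == last]


# Each pruning layer only sees the survivors of the previous one, so the token sets are nested.
after_first = percentile_prune([10, 11, 12, 13, 14, 15], [0.1, 0.4, 0.2, 0.3, 0.05, 0.5], k=50)
after_second = percentile_prune(after_first, [0.6, 0.1, 0.3], k=50)
print(after_first, after_second)
```

Using a percentile rather than a fixed score cutoff means the same k prunes a predictable share of tokens even when attention scores are distributed very differently across layers and tasks.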

Experiments described later show that performance depends on both where pruning happens and how many tokens are pruned. At a given Transformer layer, model performance declines gradually as more tokens are pruned.

They also found that pruning in later layers yielded better performance than pruning in earlier layers, suggesting that later layers are less sensitive to token pruning. To better balance speed and accuracy, the team used a progressive pruning method as shown in Figure 4, which retains more tokens in early layers and then gradually reduces the number of tokens as they flow to later layers.
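The compute side of this trade-off can be sketched with simple arithmetic. Assuming the per-token cost of a layer is roughly constant (true for the projection and MLP terms, ignoring attention's quadratic part), pre-filling cost is proportional to the number of tokens processed, summed over layers. The schedules below, which keep 70% of the remaining tokens at three layers, are hypothetical examples, not the configuration used in the paper.

```python
def tokens_per_layer(num_tokens: int, num_layers: int, schedule: dict[int, int]) -> list[int]:
    """Tokens processed at each layer; schedule maps layer index -> percent kept from there on."""
    alive = num_tokens
    counts = []
    for layer in range(num_layers):
        if layer in schedule:
            alive = max(1, alive * schedule[layer] // 100)  # integer math avoids rounding surprises
        counts.append(alive)
    return counts


late = tokens_per_layer(1000, 32, {8: 70, 16: 70, 24: 70})
early = tokens_per_layer(1000, 32, {1: 70, 2: 70, 3: 70})
for name, counts in (("late", late), ("early", early)):
    print(name, "relative layer-token cost:", sum(counts) / (1000 * 32))
```

Pruning early saves more compute, but early pruning also hurts accuracy more, which is why the progressive design keeps most tokens in the first layers and shrinks the set as tokens move deeper.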

Aux Cache

There is no KV cache during the pre-filling phase; each token is represented by its hidden state. Progressive token pruning can therefore be implemented by removing the hidden states of pruned tokens. Extending progressive token pruning to the subsequent decoding steps is harder, however, because each decoding step computes attention using the KV cache built during pre-filling. Since LazyLLM prunes tokens progressively during pre-filling, the keys and values of a token pruned at a given layer are absent from the KV caches of all later layers.

Recall that the LazyLLM framework lets each generation step pick a different subset of tokens from the full input sequence, regardless of whether those tokens were pruned in earlier steps. For example, a later decoding step may re-select a pruned token that is missing from the KV cache for its attention computation. In that case, the model cannot retrieve the token's KV cache.

An intuitive solution is to run these tokens through the Transformer again, starting from the first layer. However, this computes the same token repeatedly and ultimately slows down overall generation.

To solve this problem, the team introduced another cache in addition to the original KV cache: Auxiliary Cache.

If the keys and values of pruned tokens (such as T4 and T7 in Figure 4) are absent from the KV cache of subsequent layers, their hidden states are saved in the Aux Cache so that later iterations can retrieve them.

As shown in Figure 4, at each decoding step, each Transformer layer first retrieves the KV cache of past tokens (if any). For tokens that are not in the KV cache, the layer retrieves their hidden states directly from the Aux Cache of the previous layer instead of running the previous layer again. The Aux Cache ensures that each token is computed at most once in each Transformer layer, and that even in the worst case LazyLLM is no slower than the standard LLM.
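The bookkeeping can be illustrated with a small simulation. This is a hypothetical sketch of the idea as described above, not LazyLLM's implementation: hidden states and K/V entries are placeholder strings, the class and method names are invented, and each step's per-layer token selections are supplied by the caller (nested, so each layer uses a subset of the previous layer's tokens). It records how many times each layer processes each token.

```python
from collections import defaultdict


class LazyCacheSim:
    """Toy bookkeeping for a per-layer KV cache plus an Aux Cache (hypothetical names)."""

    def __init__(self, num_layers: int):
        self.num_layers = num_layers
        self.kv = [dict() for _ in range(num_layers)]   # kv[l][t]: K/V of token t at layer l
        self.aux = [dict() for _ in range(num_layers)]  # aux[l][t]: hidden state of t leaving layer l
        self.computed = defaultdict(int)                 # (l, t) -> times layer l processed token t

    def _run_layer(self, layer: int, token: int, hidden: str) -> str:
        # Stand-in for one Transformer layer applied to one token.
        self.computed[(layer, token)] += 1
        self.kv[layer][token] = f"kv{layer}({hidden})"
        return f"h{layer}({hidden})"

    def step(self, selected: list[set[int]]) -> None:
        """Run one generation step; selected[l] is the token set used at layer l (nested)."""
        prev_out: dict[int, str] = {}
        for layer in range(self.num_layers):
            out: dict[int, str] = {}
            for t in sorted(selected[layer]):
                if t in self.kv[layer]:
                    continue  # already computed at this layer in an earlier step: reuse its K/V
                if layer == 0:
                    hidden = f"emb{t}"
                elif t in prev_out:
                    hidden = prev_out[t]
                else:
                    hidden = self.aux[layer - 1][t]  # pruned earlier: resume from the Aux Cache
                out[t] = self._run_layer(layer, t, hidden)
                is_last = layer == self.num_layers - 1
                if not is_last and t not in selected[layer + 1]:
                    self.aux[layer][t] = out[t]  # pruned before the next layer: keep its state
            prev_out = out


sim = LazyCacheSim(num_layers=3)
sim.step([{0, 1, 2, 3, 4, 5}, {0, 1, 2, 3}, {0, 1}])            # pre-filling: prune progressively
sim.step([{0, 1, 2, 3, 4, 5, 6}, {0, 1, 3, 4, 6}, {1, 4, 6}])  # decoding: new token 6, revive token 4
print(max(sim.computed.values()), sum(sim.computed.values()))
```

Here the decoding step revives token 4, which was pruned after layer 0 during pre-filling: layer 1 resumes it from `aux[0]` instead of recomputing layer 0, and every (layer, token) pair is computed exactly once.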

Experiments

The team evaluated this “lazy” approach on two large language models: Llama 2 7B and XGen 7B. The standard LLMs used for comparison are the same publicly available pretrained checkpoints, with no additional training.

The benchmark is LongBench, a multi-task benchmark for long-context understanding. LongBench contains 16 datasets covering 6 task types: single-document question answering, multi-document question answering, summarization, few-shot learning, synthetic tasks, and code completion.

Each method is evaluated on its trade-off between TTFT speedup and accuracy.
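For readers who want to reproduce this kind of measurement, the sketch below times TTFT for any token generator and computes the speedup as baseline TTFT divided by method TTFT. The generator interface and function names are assumptions for illustration, and `fake_model` merely sleeps to stand in for pre-filling.

```python
import time
from typing import Callable, Iterator


def time_to_first_token(generate: Callable[[list[int]], Iterator[int]], prompt: list[int]) -> float:
    """Seconds from calling `generate` until it yields its first token."""
    start = time.perf_counter()
    next(iter(generate(prompt)))  # the first token is only available after pre-filling
    return time.perf_counter() - start


def ttft_speedup(baseline_seconds: float, method_seconds: float) -> float:
    return baseline_seconds / method_seconds


def fake_model(prefill_seconds: float) -> Callable[[list[int]], Iterator[int]]:
    """Hypothetical stand-in model whose pre-filling just sleeps."""
    def generate(prompt: list[int]) -> Iterator[int]:
        time.sleep(prefill_seconds)  # pretend to process the whole prompt
        yield len(prompt) % 7        # then emit a first token
    return generate


baseline = time_to_first_token(fake_model(0.02), list(range(100)))
lazy = time_to_first_token(fake_model(0.01), list(range(100)))
print(f"TTFT speedup: {ttft_speedup(baseline, lazy):.2f}x")
```

Because a Python generator's body only runs on the first `next()`, the timed interval covers the simulated pre-filling. Real measurements should average over many prompts, since single timings are noisy.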

Results

Table 1 gives the TTFT speedup and accuracy results of LazyLLM, standard LLM and other baseline methods.



In this table, “baseline” refers to standard LLM inference. Random token drop prunes tokens at random. Static token pruning prunes the input tokens once, during the pre-filling phase, based on the attention maps of earlier Transformer layers. Prompt Compression uses an LLM to remove redundancy from the input context.

As Table 1 shows, LazyLLM achieves better TTFT speedups while its drop in accuracy is negligible. Note that compressing prompts with an LLM is computationally expensive, so even though Prompt Compression makes the subsequent inference faster, its actual TTFT is longer than that of the standard LLM.

Impact on overall generation speed

To evaluate the new approach's impact on overall generation speed, the team analyzed the percentage of prompt tokens used in the computation and the resulting generation speedup (Table 2).



As the table shows, the percentage of prompt tokens used in LazyLLM's computation is always below 100%: by the end of generation, LazyLLM has not used all the tokens in the prompt, even though in principle the model could use all of them. This provides additional speedup for the overall generation process across tasks.

Discard rate at different layers

The team also analyzed the impact of the location of the pruning layer and the number of pruned tokens. The results are shown in Figure 6.



We can see that when pruning is performed at the same Transformer layer, the fewer tokens are left, the worse the model performance is. This is consistent with our intuition. In addition, pruning at later layers will yield better performance than pruning at earlier Transformer layers, which suggests that later layers are less sensitive to token pruning.

These observations support the effectiveness of progressive token pruning.

Progressive KV growth

Finally, the team examined how the model uses prompt tokens under the token-pruning logic. Specifically, they measured the cumulative fraction of prompt tokens that have been used and, correspondingly, the fraction that remains unused. This “cumulative token usage” can equivalently be defined as the KV cache size at each step. Figure 7 shows the cumulative prompt token usage at each stage of LazyLLM.
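A sketch of this metric, assuming we log which prompt tokens each generation step selected at a given layer (the logging format, the example selections and the function name are hypothetical): usage at step s is the fraction of prompt tokens selected at least once in steps 1 to s. With the Aux Cache, each such token gets a KV entry at that layer exactly once, so the curve tracks how many prompt-token entries that layer's KV cache holds.

```python
def cumulative_usage(selections: list[set[int]], num_prompt_tokens: int) -> list[float]:
    """Fraction of prompt tokens used at least once, after each generation step."""
    seen: set[int] = set()
    curve = []
    for chosen in selections:
        # Count only prompt positions; newly generated tokens are ignored here.
        seen |= {t for t in chosen if t < num_prompt_tokens}
        curve.append(len(seen) / num_prompt_tokens)
    return curve


steps = [{0, 1}, {1, 2}, {5, 12}]  # hypothetical per-step selections; 12 is a generated token
curve = cumulative_usage(steps, num_prompt_tokens=10)
print(curve, "never used:", 1 - curve[-1])
```

The curve can only grow; if it ends below 1.0, some prompt tokens were never computed at that layer at all.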



This result supports the hypothesis that many tokens are never selected by the model, even though in principle the model could use every token in the prompt.

Since the model still maintains its task accuracy, this suggests that it can effectively discard tokens that do not affect output quality.