
Two small models verify each other and directly rival a large model? Microsoft's rStar doesn't even use CoT

2024-08-16




Machine Heart Report

Editor: Panda

By checking each other's answers, small models can also solve big problems.

As we all know, LLMs are powerful, yet they still struggle with complex reasoning.

For example, on the GSM8K dataset, Mistral-7B achieves only 36.5% accuracy, even with techniques such as Chain of Thought (CoT). Although fine-tuning can effectively improve reasoning ability, most LLMs rely on fine-tuning data that has been distilled, or even synthesized, by more powerful models such as GPT-4.

Meanwhile, researchers are also actively pursuing a helpful but more demanding approach: using a stronger teacher LLM to improve reasoning ability.

To improve reasoning in the absence of a better model, a promising paradigm is to exploit the knowledge in the LLM itself. For example, a method called RAP uses a self-exploration approach to iteratively improve the reasoning performance of the LLM through self-rewarding feedback. Unfortunately, research shows that this paradigm has two fundamental problems.

First, LLMs often struggle to explore the solution space effectively during reasoning. Because of low-quality reasoning steps, this self-exploration frequently gets stuck in a poor region of the solution space, even after many attempts.

Second, even when self-exploration does find high-quality reasoning steps, it is difficult for a small language model (SLM) to tell which reasoning steps are of higher quality or whether a final answer is correct, which makes it hard to guide self-exploration effectively. Studies have shown that guiding self-exploration with basic, conventional rewards is no better than random guessing.

What's more troublesome is that small language models (SLMs) are even more prone to these two problems because of their weaker capabilities. For example, GPT-4 can improve its outputs through self-refinement, but SLMs struggle to do so and may even degrade the quality of their outputs. This seriously hinders the broader adoption of such self-improvement approaches.

To address these issues, a research team from Microsoft Research Asia and Harvard University proposed Self-play mutual reasoning, or rStar. In simple terms, this method is similar to having two mediocre students check each other's test answers, ultimately improving their scores to the level of top students. The team claims that rStar "can improve the reasoning ability of SLM without fine-tuning or better models."



  • Paper title: Mutual Reasoning Makes Smaller LLMs Stronger Problem-Solvers
  • Paper address: https://arxiv.org/pdf/2408.06195
  • Code address: https://github.com/zhentingqi/rStar (to be released)

Method

In order to solve the above problems, rStar divides the reasoning process into two parts: solution generation and mutual verification, as shown in Figure 2.



For the first challenge, the team introduced a rich set of human-like reasoning actions that enables thorough exploration of diverse reasoning spaces.

For the second challenge, they designed a reward function tailored to SLMs, which evaluates intermediate steps instead of relying on the SLM's often unreliable self-evaluation.

In addition, the team uses another SLM as a discriminator to augment the MCTS process; this discriminator SLM mutually verifies the correctness of each candidate trajectory with the target SLM.

Self-generating reasoning trajectories with MCTS rollouts

A rich set of human-like reasoning actions. The core of MCTS generation is the action space, which defines the scope of tree exploration. Most MCTS-based methods use a single action type when building the tree. For example, the action in RAP is to ask the next sub-question, while the action in AlphaMath and MindStar is to generate the next reasoning step. However, relying on a single action type can easily lead to poor space exploration.

To address this problem, the team reviewed the way humans perform reasoning. Different people approach problems differently: some break the problem down into subproblems, others solve the problem directly, and some rephrase the problem from a different perspective. In addition, people adjust their approach based on the current state, choosing different actions as needed.

Inspired by the human reasoning process, the team constructed a richer action set containing five categories of actions, to maximize the potential of the SLM to correctly solve complex reasoning problems.

Action 1: Propose a single thinking step. For a given problem, this action has the LLM generate the next thinking step based on the existing reasoning steps.

Action 2: Propose the remaining thought steps. This action, like the standard CoT, enables "fast thinking" to solve simple problems that require only a small number of steps. Given the reasoning steps that have been generated, it will let the LLM directly generate the remaining steps until the final answer is obtained.

Action 3: Propose the next sub-question and its answer.

Action 4: Answer this sub-question again. Considering that action 3 may not correctly answer the corresponding sub-question, the purpose of this action is to answer it again.

Action 5: Restate the problem/sub-problem. This new action reformulates the problem in a simpler form; specifically, the LLM is asked to explicitly list all conditions in the problem statement.

The above five actions define a highly diverse action space {A1, A2, A3, A4, A5}.

At each step i, MCTS selects an action a_i from this space. Then, based on the current state (i.e., the previously generated trajectory x ⊕ s_1 ⊕ s_2 ⊕ ... ⊕ s_{i−1}), the action a_i is used to prompt the LLM to generate the next reasoning step s_i. Note that some actions must be executed in sequence. Figure 3 gives an example.
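To make this concrete, here is a minimal Python sketch of how such an action space and per-step generation could look. The prompt wording and the llm_generate helper are illustrative assumptions, not the authors' actual implementation.

```python
# Illustrative sketch of a five-action space for MCTS step generation.
# The prompt templates and the llm_generate() callable are assumptions,
# not the authors' actual code.

ACTION_PROMPTS = {
    "A1": "Propose the next single reasoning step.",
    "A2": "Propose all remaining reasoning steps until the final answer (CoT-style).",
    "A3": "Propose the next sub-question and answer it.",
    "A4": "Answer the last sub-question again, more carefully.",
    "A5": "Restate the problem by listing all conditions in the statement.",
}

def generate_step(llm_generate, question, steps, action):
    """Generate the next reasoning step s_i from the state x ⊕ s_1 ⊕ ... ⊕ s_{i-1}."""
    state = question + "\n" + "\n".join(steps)
    prompt = f"{ACTION_PROMPTS[action]}\n\n{state}\n"
    return llm_generate(prompt)

# Example usage: new_step = generate_step(slm.generate, question, steps_so_far, "A3")
```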



As shown in Table 1, each action plays an important role in improving the final inference accuracy.



  • Reward Function

Another key component of MCTS is the reward function, which evaluates the value of each action and guides the expansion of the tree. For SLMs, the team designed a simple but effective reward function. Their approach is inspired by AlphaGo, which scores each intermediate node according to its contribution to the final correct answer. In this way, actions that frequently lead to correct answers receive higher rewards and are more likely to be selected in future MCTS tree expansions.

Here, the reward value of the node s generated by executing action a is defined as Q(s, a). Initially, all unexplored nodes are assigned Q(s_i, a_i) = 0, which makes the initial tree expansion random. When the first end node n_d is reached, a reward score Q(s_d, a_d) is computed based on whether it reaches the correct answer.

This score is then back-propagated to every intermediate node along the trajectory t = x ⊕ s_1 ⊕ s_2 ⊕ ... ⊕ s_d. Specifically, for each s_i, its Q value is updated as Q(s_i, a_i) = Q(s_i, a_i) + Q(s_d, a_d). To compute Q(s_d, a_d) for the end node, the reward value used is the likelihood (confidence) of the self-consistent majority vote.
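As a rough sketch of this update rule (assuming hypothetical node objects with Q and N fields), the terminal score, taken as the confidence of the self-consistent majority vote, is simply added to every node on the path:

```python
from collections import Counter

def terminal_reward(sampled_answers, final_answer):
    """Confidence of the self-consistent majority vote at the end node s_d."""
    votes = Counter(sampled_answers)
    return votes[final_answer] / max(1, len(sampled_answers))

def backpropagate(path_nodes, reward):
    """Q(s_i, a_i) <- Q(s_i, a_i) + Q(s_d, a_d) along the trajectory."""
    for node in path_nodes:      # nodes s_1 ... s_d from root to terminal
        node.Q += reward         # unexplored nodes start with Q = 0
        node.N += 1              # visit count, used later by UCT selection
```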

  • Generate solutions using MCTS Rollout

The following describes how MCTS generates candidate reasoning trajectories. Starting from the initial root node s_0, the search performs selection, expansion, simulation, and backpropagation. The simulation uses the default rollout policy, and multiple rollouts are performed to obtain more accurate reward estimates. To balance exploration and exploitation, the well-known UCT (Upper Confidence bounds applied to Trees) criterion is used to select each node. Its mathematical form is:

UCT(s, a) = Q(s, a) / N(s, a) + c * sqrt( ln N_parent(s) / N(s, a) )

where N(s, a) is the number of times node s has been visited in previous iterations, N_parent(s) is the visit count of the parent node of s, Q(s, a) is the estimated reward value (updated during backpropagation), and c is a constant that balances exploration and exploitation.

Once the search reaches a terminal node (either a terminal state or the predefined maximum tree depth d), a trajectory from the root to that terminal node is obtained. All trajectories collected over the rollout iterations form the candidate solutions, which next need to be verified.
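Assuming each tree node stores its Q value, visit count N, parent, and children (hypothetical field names, not the paper's data structures), the UCT-based selection phase of a rollout might look like the following sketch:

```python
import math

def uct(node, c=1.4):
    """UCT score: Q(s, a) / N(s, a) + c * sqrt(ln N_parent(s) / N(s, a)).
    The value of c here is arbitrary; it only balances exploration vs. exploitation."""
    if node.N == 0:
        return float("inf")          # always try unvisited children first
    return node.Q / node.N + c * math.sqrt(math.log(node.parent.N) / node.N)

def select(root):
    """Selection phase: descend the tree, always taking the child with the
    highest UCT score, until reaching a leaf to expand or simulate from."""
    node = root
    while node.children:
        node = max(node.children, key=uct)
    return node
```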

Selecting reasoning trajectories using mutual consistency

Based on all the collected trajectories, the team proposed using reasoning consistency to select the answer.

  • Reasoning consistency through discriminator SLM

As shown in Figure 2, in addition to the target SLM, the team also introduced a discriminator SLM, whose role is to provide external unsupervised feedback for each candidate trajectory.

Specifically, for t = x ⊕ s_1 ⊕ s_2 ⊕ ... ⊕ s_d, the reasoning steps starting from a randomly sampled step i are masked out. The earlier part of the trajectory, x ⊕ s_1 ⊕ s_2 ⊕ ... ⊕ s_{i-1}, is then provided as a prompt to the discriminator SLM to complete the remaining steps. Because the first i-1 reasoning steps are given as a prompt, the difficulty is reduced and the discriminator SLM is more likely to arrive at the correct answer.

As shown in Figure 4, the answer completed by the discriminator SLM is compared with that of the original trajectory t. If the two agree, t is considered a mutually verified trajectory eligible for final selection.
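A hedged sketch of this mutual-consistency check is given below; the discriminator_generate and extract_answer helpers are hypothetical placeholders for the discriminator SLM call and answer parsing.

```python
import random

def mutually_consistent(discriminator_generate, question, steps, extract_answer):
    """Mask the trajectory from a random step i onward, let the discriminator
    SLM complete it, and check whether both answers agree."""
    i = random.randrange(1, len(steps))                  # random split point (assumes >= 2 steps)
    prefix = question + "\n" + "\n".join(steps[:i])      # x ⊕ s_1 ⊕ ... ⊕ s_{i-1} as prompt
    completion = discriminator_generate(prefix)          # discriminator fills in the rest
    return extract_answer(completion) == extract_answer("\n".join(steps))
```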



The target SLM selects the final trajectory. After mutual reasoning consistency has been applied to all candidate trajectories, the target SLM selects the final trajectory from among the verified ones. To compute the final score of each trajectory, the team multiplies its reward by the confidence score of its end node obtained through rollouts. The trajectory with the highest final score is chosen as the solution.
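In code form, this final selection reduces to an argmax over the verified trajectories; the (trajectory, reward, confidence) tuple layout below is an assumed illustration.

```python
def select_final_trajectory(verified):
    """Pick the verified trajectory whose reward × end-node confidence is highest."""
    # verified: list of (trajectory, reward, end_node_confidence) tuples
    best = max(verified, key=lambda item: item[1] * item[2])
    return best[0]
```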

Experiments

Experimental setup

rStar is applicable to a variety of LLMs and reasoning tasks. The team evaluated 5 SLMs: Phi3-mini, LLaMA2-7B, Mistral-7B, LLaMA3-8B, LLaMA3-8B-Instruct.

There are 5 reasoning tasks tested, including 4 math tasks (GSM8K, GSM-Hard, MATH, SVAMP) and 1 common sense task (StrategyQA).

For experimental details, please refer to the original paper.

Key results

The team first evaluated the effectiveness of rStar on general reasoning benchmarks. Table 2 compares the accuracy of rStar and other current best methods across different SLMs and reasoning datasets. To demonstrate the effect of the new generator, the team also reports the accuracy of rStar (generator @maj), i.e., the accuracy obtained without the discriminator, using only majority voting to verify answers.



The team noted three key results:

1. SLMs become noticeably better problem solvers with the help of rStar. For example, on the GSM8K dataset, the accuracy of LLaMA2-7B with few-shot CoT is only 12.51%, but with rStar it rises to 63.91%, close to the accuracy obtained with fine-tuning, as shown in Figure 1. Similarly, Mistral with rStar even outperforms the MetaMath fine-tuned version by 4.18%. These improvements show that SLMs already possess strong reasoning ability but need guidance to generate and select correct answers.



2. rStar can consistently improve the reasoning accuracy of various SLMs evaluated on different tasks to the current state-of-the-art. In contrast, none of the other comparison methods can consistently achieve good performance on all four benchmarks. For example, although SC (Self-Consistency) is good at three math tasks, it cannot effectively solve the logical reasoning task of StrategyQA.

3. Even without the proposed discriminator for verifying reasoning trajectories, the proposed MCTS generator still effectively improves the reasoning accuracy of SLMs. For example, on the GSM8K dataset, rStar (generator @maj) is 2.88%-16.39% more accurate than RAP, 10.60%-38.37% more accurate than ToT, and 1.69%-7.34% more accurate than SC.

  • Results on difficult math datasets

The team also evaluated rStar on a more difficult math dataset. They chose the GSM-Hard and MATH datasets for this purpose. Following the convention of similar studies, they used MATH-500, a subset of representative problems from the MATH dataset. This was done to increase the evaluation speed. As shown in Tables 2 and 3, rStar was able to significantly improve the reasoning accuracy of SLM on these difficult math datasets.



Ablation studies

  • Effectiveness of different rollouts

rStar uses the Rollout strategy to perform MCTS tree expansion. More Rollouts will generate more candidate solution trajectories, but will also increase the inference cost. Figure 5 compares the accuracy of SC, RAP, and rStar using different Rollouts on GSM8K.



Two key observations are made here:

1. Even with only 2 rollouts, rStar can significantly improve the inference accuracy of SLM, which shows its effectiveness;

2. More rollouts are beneficial to both rStar and SC, while RAP tends to saturate or even decline after 4 rollouts. One reason is that the single-type action space of RAP limits the effect of MCTS exploration.

  • Effectiveness of MCTS Generator

The team compared the MCTS generator with three other generators. As shown in Table 4, the newly proposed MCTS generator outperforms the others across the board. The table also demonstrates the effectiveness of the reward function tailored for SLMs, since replacing it with self-evaluation reduces the accuracy of the new generator.



  • Effectiveness of the Discriminator

The team set up two evaluation experiments.

The first experiment compares the discriminator method with majority voting and self-verification. The results in Table 5 (left) show a very significant advantage for the discriminator method.



The second experiment studies the impact of different discriminator models. The results in Table 5 (right) show that the choice of discriminator model generally does not affect the effectiveness of the reasoning consistency method in verifying answers. Notably, even using the powerful GPT-4 as the discriminator improves performance only slightly (from 91.13% to 92.57%). This shows that the reasoning consistency method can effectively rely on an SLM to verify answers.