The Quest for Q*: Leveraging Q-Learning to Enhance LLMs
Introduction:
In the rapidly evolving world of artificial intelligence, Large Language Models (LLMs) have emerged as a game-changer, capable of tackling a wide array of tasks with remarkable effectiveness. Much of their recent progress comes from combining LLMs with Reinforcement Learning (RL), a machine learning technique in which a model learns from the outcomes of its actions and improves its behavior over time.
One of the most intriguing developments in this field surfaced in late 2023, during OpenAI's leadership crisis, when rumors of a breakthrough algorithm called Q* began to circulate. Q* was reportedly able to solve grade-school math problems, even ones it had never encountered before. Despite the buzz, OpenAI has stayed tight-lipped about the details of Q*, leaving the AI community to speculate about its inner workings.
Background
For those unfamiliar with the concept, Q* is believed to be rooted in Q-Learning, a Reinforcement Learning algorithm. In Q-Learning, the “Q-value” Q(s, a) is the expected future reward for taking action a in state s. The optimal action-value function, denoted Q*, gives the maximum expected cumulative (discounted) future reward achievable by taking action a in state s and then acting optimally. Picking the action with the highest Q* value in each state therefore yields an optimal policy.
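To make the definition concrete, here is a minimal tabular Q-Learning sketch. It repeatedly applies the update Q(s, a) <- Q(s, a) + alpha * [r + gamma * max_a' Q(s', a') - Q(s, a)], whose fixed point is Q*. The five-state chain environment, the learning rate alpha, the discount gamma, and the exploration rate are all illustrative assumptions and have nothing to do with OpenAI's Q*.

```python
import random
from typing import Dict, Tuple

# Hypothetical toy environment: states 0..4 on a line; reaching state 4 ends the
# episode with reward 1, every other transition gives reward 0.
N_STATES = 5
GOAL = N_STATES - 1
ACTIONS = (-1, 1)  # move left or move right

def step(state: int, action: int) -> Tuple[int, float, bool]:
    next_state = min(max(state + action, 0), GOAL)
    done = next_state == GOAL
    return next_state, (1.0 if done else 0.0), done

def q_learning(episodes: int = 500, alpha: float = 0.1, gamma: float = 0.9,
               epsilon: float = 0.1, seed: int = 0) -> Dict[Tuple[int, int], float]:
    rng = random.Random(seed)
    q = {(s, a): 0.0 for s in range(N_STATES) for a in ACTIONS}
    for _ in range(episodes):
        state, done = 0, False
        while not done:
            if rng.random() < epsilon:
                action = rng.choice(ACTIONS)  # explore
            else:
                best = max(q[(state, a)] for a in ACTIONS)
                # exploit, breaking ties at random so early episodes still wander
                action = rng.choice([a for a in ACTIONS if q[(state, a)] == best])
            next_state, reward, done = step(state, action)
            # Target r + gamma * max_a' Q(s', a'); terminal states have no future value
            future = 0.0 if done else max(q[(next_state, a)] for a in ACTIONS)
            q[(state, action)] += alpha * (reward + gamma * future - q[(state, action)])
            state = next_state
    return q

q = q_learning()
policy = [max(ACTIONS, key=lambda a: q[(s, a)]) for s in range(GOAL)]
print(policy)  # greedy action in states 0..3; expected [1, 1, 1, 1]
```

In this toy chain, the optimal action from every non-terminal state is to move right, so the greedy policy read off the learned Q-table should be all +1.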
While the specifics of Q* remain shrouded in secrecy, some insightful analyses have emerged online, such as this interesting YouTube video. These discussions often point back to the principles of Reinforcement Learning and Q-Learning as the foundation for Q*’s potential.
One notable paper that sheds light on the possible mechanics of Q* is “Let’s Verify Step by Step” by Lightman et al. (2023). This work trains a verifier, a process-supervised reward model (PRM), that works alongside the LLM generator to check the solutions it produces. Rather than judging only the final answer, the verifier breaks a solution into a sequence of steps and learns to predict a label (positive, negative, or neutral) for each one, assessing the correctness of every step. It is trained on human-labeled data from the PRM800K dataset, which contains roughly 800,000 step-level labels.
Combining these insights, it’s conceivable that Q* is a Q-Learning approach that learns the optimal Q-value for each step of a math problem. In this framing, the state is the problem together with the steps written so far, an action is the next step, and the Q-value is the expected future reward (here, the correctness of the step) for taking that action. If the verifier can accurately judge the correctness of each step, the LLM can build its solutions from the steps the verifier rates most likely to be correct, leading to more accurate answers overall.
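The step-level framing can be sketched as a greedy search in which the verifier's score for a candidate step stands in for Q(s, a). This only illustrates the idea and is not a description of OpenAI's Q*: `propose_steps` and `verifier_score` are hypothetical callables (in practice, sampling next steps from the generator and querying the verifier), and the toy arithmetic checker below merely plays the verifier's role.

```python
import operator
from typing import Callable, List, Sequence

# State: the problem plus the steps chosen so far. Action: the next step.
# The verifier's score for (state, action) stands in for Q(s, a).
Steps = List[str]

def greedy_step_search(problem: str,
                       propose_steps: Callable[[str, Steps], Sequence[str]],
                       verifier_score: Callable[[str, Steps, str], float],
                       is_final: Callable[[str], bool],
                       max_steps: int = 10) -> Steps:
    steps: Steps = []
    for _ in range(max_steps):
        candidates = propose_steps(problem, steps)  # hypothetical: sample next steps
        if not candidates:
            break
        # Choose argmax_a Q(s, a), using the verifier score as the Q estimate
        best = max(candidates, key=lambda c: verifier_score(problem, steps, c))
        steps.append(best)
        if is_final(best):
            break
    return steps

# --- Toy stand-ins, for illustration only ---
OPS = {"+": operator.add, "*": operator.mul}

def toy_propose(problem: str, steps: Steps) -> List[str]:
    if not steps:
        return ["3 * 4 = 7", "3 * 4 = 12"]
    return ["2 + 12 = 15. Answer: 15", "2 + 12 = 14. Answer: 14"]

def toy_verifier(problem: str, steps: Steps, step: str) -> float:
    """Scores 1.0 if the step's 'a op b = c' arithmetic is right, else 0.0."""
    lhs, rhs = step.split("=")
    a, op, b = lhs.split()
    claimed = int(rhs.split(".")[0])
    return 1.0 if OPS[op](int(a), int(b)) == claimed else 0.0

print(greedy_step_search("What is 2 + 3 * 4?", toy_propose, toy_verifier,
                         is_final=lambda s: "Answer:" in s))
# ['3 * 4 = 12', '2 + 12 = 14. Answer: 14']
```

Picking only the single best next step is the simplest possible policy; it uses the verifier as a per-step judge but does no lookahead or learning of its own.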
In this project, my goal is to investigate using Reinforcement Learning to fine-tune LLMs and to explore the potential of RL for language models.
Methods
A notable recent advancement comes from the paper “Math-Shepherd: Verify and Reinforce LLMs Step-by-step without Human Annotations” by Wang et al. (2024), which builds on the earlier work of Lightman et al. (2023). As the title indicates, Math-Shepherd labels solution steps automatically instead of relying on human annotators, which makes it much cheaper to build large step-level datasets for training verifiers.
While Wang et al. have been more transparent than some other groups in this space, such as OpenAI, they still hold back full details and code. Nevertheless, what they have shared could significantly improve how we train language models.
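As I understand the paper, Math-Shepherd labels each step by sampling several completions that start from the solution prefix ending at that step and checking whether they reach the gold answer. With hard estimation, a step is labeled correct if any completion reaches the gold answer; with soft estimation, its label is the fraction that do. The sketch below assumes a hypothetical `complete` callable that returns the final answers of N sampled completions; the toy completer exists only to make it runnable.

```python
from typing import Callable, List, Sequence

def annotate_steps(steps: Sequence[str],
                   gold_answer: str,
                   complete: Callable[[Sequence[str]], List[str]],
                   hard: bool = True) -> List[float]:
    """Label each step by rolling out completions from the prefix ending at that step.

    complete(prefix) is a hypothetical callable that samples N >= 1 completions
    from the prefix and returns their final answers.
    hard=True  -> 1.0 if any completion reaches the gold answer, else 0.0
    hard=False -> fraction of completions that reach the gold answer
    """
    labels: List[float] = []
    for i in range(1, len(steps) + 1):
        answers = complete(steps[:i])
        hits = sum(ans == gold_answer for ans in answers)
        labels.append(float(hits > 0) if hard else hits / len(answers))
    return labels

# Toy completer: once a step containing "mistake" is in the prefix, no rollout recovers.
def toy_complete(prefix: Sequence[str]) -> List[str]:
    if any("mistake" in s for s in prefix):
        return ["13", "13", "13", "13"]
    return ["14", "14", "13", "14"]

steps = ["3 * 4 = 12", "2 + 12 = 14", "14 - 1 = 13 (mistake)"]
print(annotate_steps(steps, "14", toy_complete))              # [1.0, 1.0, 0.0]
print(annotate_steps(steps, "14", toy_complete, hard=False))  # [0.75, 0.75, 0.0]
```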
Below, I describe the verifier method in more detail, share my experience implementing it, and compare it with traditional approaches. I experimented with different base models and configurations and ran into practical issues, such as tokenization, that can significantly affect the outcome. By comparing methods and models, such as TinyLlama 1.1B and Llama 3 8B, I aim to illustrate the practical impact of these techniques on model training and performance.
Improving the Dataset and Objectives:
The Math-Shepherd paper comes with its own annotated data, so I use that instead of OpenAI’s PRM800K release. Rather than classifying each step directly, the paper represents each step’s label as a token appended to the step: ‘+’ for a correct step and ‘-’ for an incorrect one. The authors claim that this binary representation (correct/incorrect) performs similarly to OpenAI’s three-label representation (correct/neutral/incorrect). OpenAI’s paper itself effectively merges the neutral and correct labels by treating neutral steps as correct, which further supports a binary representation and removes the extra complexity of the three-label system.
The dataset the authors released on Hugging Face uses ‘ки’ as the step tag in the input text, and the model is trained to predict either ‘+’ or ‘-’ at each tag position. This lets the entire training process fit into a regular language-modeling framework instead of requiring a classification head. With a classification head that scores each step prefix separately, a solution with n steps needs n predictions, whereas the language-modeling setup computes the loss for every step in a single forward pass. Having trained both ways, I found the language-modeling approach more efficient and simpler.
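A minimal sketch of the label preparation, assuming three-way ratings of 1 (correct), 0 (neutral), and -1 (incorrect): neutral steps are folded into the correct class, and each step is followed by either the step tag ‘ки’ (in the input text) or its ‘+’/‘-’ label (in the target text). The one-step-per-line layout and the `format_example` helper are illustrative assumptions. Check the exact spacing against the released data, because spacing changes how the label tokens are tokenized.

```python
from typing import Sequence, Tuple

STEP_TAG = "ки"  # step tag used in the released Math-Shepherd data

def binarize(rating: int) -> str:
    """Map a three-way rating (1 correct, 0 neutral, -1 incorrect) to '+'/'-'.
    Neutral steps are treated as correct."""
    return "-" if rating < 0 else "+"

def format_example(question: str, steps: Sequence[str],
                   ratings: Sequence[int]) -> Tuple[str, str]:
    """Return (input_text, label_text). Assumed layout: one step per line, each followed
    by a space and either the step tag (input) or its '+'/'-' label (target)."""
    if len(steps) != len(ratings):
        raise ValueError("need one rating per step")
    tagged = "\n".join(f"{s} {STEP_TAG}" for s in steps)
    labeled = "\n".join(f"{s} {binarize(r)}" for s, r in zip(steps, ratings))
    return f"{question} {tagged}", f"{question} {labeled}"

inp, tgt = format_example("What is 2 + 3 * 4?", ["3 * 4 = 12", "2 + 12 = 15"], [1, -1])
print(inp)
print(tgt)
```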
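The single-forward-pass idea can be sketched in plain Python. Every position gets an “ignore” target except the step-tag positions, whose target is the ‘+’ or ‘-’ token, so one pass over the sequence yields the loss for all n steps. At inference, a step’s score is P(‘+’) renormalized over just ‘+’ and ‘-’. The toy vocabulary and logits are assumptions. In this sketch, the logits are read directly at the tag position with no causal shift. Hugging Face causal-LM models shift labels by one position internally when given `labels`, so a real training setup must align the targets accordingly.

```python
import math
from typing import List, Sequence

IGNORE_INDEX = -100  # positions with this target add no loss (PyTorch's default ignore_index)

def build_targets(input_ids: Sequence[int], step_tag_id: int, step_labels: Sequence[bool],
                  plus_id: int, minus_id: int) -> List[int]:
    """Ignore every position except step tags; a tag's target is '+' or '-'."""
    targets = [IGNORE_INDEX] * len(input_ids)
    tag_positions = [i for i, tok in enumerate(input_ids) if tok == step_tag_id]
    if len(tag_positions) != len(step_labels):
        raise ValueError("need exactly one label per step tag")
    for pos, correct in zip(tag_positions, step_labels):
        targets[pos] = plus_id if correct else minus_id
    return targets

def masked_cross_entropy(logits: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
    """Mean cross-entropy over non-ignored positions, from ONE pass of logits.
    logits[t] is read directly at position t (no causal shift in this sketch)."""
    total, count = 0.0, 0
    for row, target in zip(logits, targets):
        if target == IGNORE_INDEX:
            continue
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))  # log-sum-exp
        total += log_z - row[target]
        count += 1
    return total / count

def step_scores(logits: Sequence[Sequence[float]], input_ids: Sequence[int],
                step_tag_id: int, plus_id: int, minus_id: int) -> List[float]:
    """P('+') at each tag position, renormalized over just the '+' and '-' logits."""
    return [1.0 / (1.0 + math.exp(row[minus_id] - row[plus_id]))
            for row, tok in zip(logits, input_ids) if tok == step_tag_id]

# Toy vocabulary: 0 'a', 1 'b', 2 'ки', 3 '+', 4 '-'
ids = [0, 1, 2, 0, 2]
targets = build_targets(ids, step_tag_id=2, step_labels=[True, False], plus_id=3, minus_id=4)
logits = [[0.0] * 5, [0.0] * 5, [0.0, 0.0, 0.0, 2.0, 0.0], [0.0] * 5, [0.0] * 5]
print(targets)  # [-100, -100, 3, -100, 4]
print(step_scores(logits, ids, 2, 3, 4))
print(masked_cross_entropy(logits, targets))
```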
Base Reward Models:
I experimented with both TinyLlama 1.1B and Llama 3 8B as base reward models, and Llama 3 8B performed better. However, the tokenizer needs special attention during training: whether a label token has a space or another token next to it can change how it is tokenized and, in turn, how well the model learns. For example, ‘+’ and ‘ +’ may be encoded as different tokens, so a label that appears with inconsistent spacing gives the model an inconsistent target that it struggles to learn. Training fits on a single RTX 4090 with 24 GB of VRAM using Unsloth. Even so, my fine-tuned Llama 3 8B verifier did not perform as well as the fine-tuned Mistral 7B model from the Math-Shepherd paper, possibly because LoRA fine-tuning copes less well with a large distribution shift.
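One cheap guard against this pitfall is to check that appending a label (with whatever separator the data uses) adds exactly one token and leaves the step’s own tokens unchanged. The check below is tokenizer-agnostic: `encode` is any text-to-tokens function, for example `lambda s: tokenizer.encode(s, add_special_tokens=False)` for a Hugging Face tokenizer. `toy_encode` is a hypothetical GPT-2-style tokenizer used only for illustration. Real tokenizers, including Llama 3’s, may behave differently, so run the check with the actual one.

```python
import re
from typing import Callable, Sequence

def appends_one_token(encode: Callable[[str], Sequence], prefix: str, suffix: str) -> bool:
    """True if tokenizing prefix + suffix keeps the prefix's tokens unchanged and
    adds exactly one new token, i.e. the suffix did not merge into the step text."""
    base = list(encode(prefix))
    full = list(encode(prefix + suffix))
    return len(full) == len(base) + 1 and full[: len(base)] == base

def toy_encode(text: str) -> list:
    """Hypothetical GPT-2-style toy tokenizer: a leading space attaches to the next word."""
    return re.findall(r" ?[^\s]+|\s+", text)

step = "3 * 4 = 12"
for suffix in ["+", " +", " -", "ки", " ки"]:
    print(repr(suffix), appends_one_token(toy_encode, step, suffix))
# '+' False, ' +' True, ' -' True, 'ки' False, ' ки' True
```

With this toy tokenizer, gluing the label directly onto the step text merges the two into a single token, while a leading space keeps the label separate. Which form is safe depends entirely on the real tokenizer.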
Verifier Model Performance:
I compared the verifier’s performance with two generators: TinyLlama 1.1B (LoRA fine-tuned on the GSM8K dataset) and Llama 3 8B Instruct. The baseline was majority voting over the sampled solutions. To evaluate the verifier, I followed OpenAI’s paper: the verifier scores each step, the step scores are combined into a solution-level score, and the solution with the highest score (the one least likely to contain an error) is selected. The Llama 3 8B verifier consistently improved TinyLlama’s results, but not those of the Llama 3 8B Instruct generator. This may be because the verifier is not yet well trained, or because Llama 3 8B Instruct has already been heavily optimized with Meta’s massive resources. Meta’s instruction tuning combined supervised fine-tuning (SFT), rejection sampling, proximal policy optimization (PPO), and direct preference optimization (DPO), and Meta reports that this post-training substantially improved the model’s performance.
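The two selection rules can be sketched as follows. Majority voting picks the most frequent final answer. The verifier-based rule scores each solution as the product of its per-step correctness probabilities (the probability that every step is correct, as in Lightman et al.) and keeps the top-scoring solution. Math-Shepherd instead uses the minimum step score, and either aggregation can replace `solution_score`. The sample answers and step probabilities below are made up for illustration.

```python
import math
from collections import Counter
from typing import Sequence, Tuple

def majority_vote(final_answers: Sequence[str]) -> str:
    """Baseline: the most frequent final answer among the N samples."""
    return Counter(final_answers).most_common(1)[0][0]

def solution_score(step_probs: Sequence[float]) -> float:
    """Probability that every step is correct: the product of the per-step
    probabilities (computed in log space; probabilities clipped to avoid log(0))."""
    return math.exp(sum(math.log(max(p, 1e-12)) for p in step_probs))

def best_of_n(candidates: Sequence[Tuple[str, Sequence[float]]]) -> str:
    """Return the final answer of the candidate with the highest verifier score."""
    answer, _ = max(candidates, key=lambda c: solution_score(c[1]))
    return answer

# Made-up samples: (final answer, verifier's per-step correctness probabilities)
samples = [("14", [0.9, 0.8]), ("15", [0.95, 0.3]), ("15", [0.6, 0.5])]
print(majority_vote([ans for ans, _ in samples]))  # 15
print(best_of_n(samples))                          # 14 (0.72 beats 0.285 and 0.30)
```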
The results for the TinyLlama 1.1B generator are shown in the following graph:
The results for the Llama 3 8B generator are shown in the following graph:
In our experiments, we observed consistent, albeit modest, improvements in the performance of the TinyLlama 1.1B model when using the Llama 3 8B verifier. This suggests that the verifier can improve accuracy, particularly when the generator’s baseline capabilities are weaker.
Conversely, the Llama 3 8B Instruct generator did not show similar improvements when paired with the same verifier; its accuracy was close to that of simply picking a single sample at random. This discrepancy raises intriguing questions about the interplay between model capability and verification. One possible explanation is that Llama 3 8B Instruct, developed with substantial resources from Meta, is already highly optimized for these tasks. Its extensive training may leave little room for improvement from an external verifier, especially one that is less capable than the generator itself.
Guiding Generator Model Training with the Verifier Model
In addition to training the verifier model, I used it to guide the training of the generator model, with the verifier providing the reward signal. In each epoch, I roll out the generator to sample two solutions for each question. The verifier then scores both responses, marking one as accepted (chosen) and the other as rejected, and the generator is trained on these pairs with Direct Preference Optimization (DPO). I found that without aggregation, the loss varied significantly from epoch to epoch. To address this, I followed the Dataset Aggregation (DAgger) approach and trained the generator on all responses collected so far, not just the latest ones. This aggregation stabilizes training and leads to more consistent results.
By incorporating the verifier’s feedback and the DAgger technique, the generator can learn from its past experience and improve progressively. This approach uses the verifier’s knowledge to guide the generator’s training, which should make the learning process more effective and efficient.
However, due to time constraints, I have not yet been able to fully explore this approach and its impact on the generator’s performance, and training is still in progress. Future updates could look more closely at the benefits of using the verifier to guide the generator’s training.
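A minimal sketch of this loop under stated assumptions. `sample` (draw a solution from the current generator), `score` (the verifier’s solution-level score), and `train_dpo` (one DPO pass over a list of pairs) are hypothetical callables. Skipping pairs whose two scores tie is my assumption, not part of the original setup. `dpo_loss` shows the standard per-pair DPO objective, -log sigmoid(beta * [(log pi(chosen) - log pi_ref(chosen)) - (log pi(rejected) - log pi_ref(rejected))]). The DAgger-style part is simply that the dataset is extended every epoch and never reset.

```python
import math
from typing import Callable, Dict, List

Pair = Dict[str, str]

def dpo_loss(logp_chosen: float, logp_rejected: float,
             ref_logp_chosen: float, ref_logp_rejected: float,
             beta: float = 0.1) -> float:
    """Per-pair DPO loss: -log sigmoid(beta * margin) = softplus(-beta * margin)."""
    margin = (logp_chosen - ref_logp_chosen) - (logp_rejected - ref_logp_rejected)
    x = -beta * margin
    # softplus(x) = log(1 + e^x), written to avoid overflow for large |x|
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def collect_pairs(questions: List[str],
                  sample: Callable[[str], str],
                  score: Callable[[str, str], float]) -> List[Pair]:
    """One rollout round: two samples per question; the verifier picks chosen/rejected."""
    pairs: List[Pair] = []
    for q in questions:
        a, b = sample(q), sample(q)
        sa, sb = score(q, a), score(q, b)
        if sa == sb:
            continue  # assumption: a tie gives no preference signal, so skip it
        chosen, rejected = (a, b) if sa > sb else (b, a)
        pairs.append({"prompt": q, "chosen": chosen, "rejected": rejected})
    return pairs

def train_with_aggregation(questions: List[str],
                           sample: Callable[[str], str],
                           score: Callable[[str, str], float],
                           train_dpo: Callable[[List[Pair]], None],
                           epochs: int = 3) -> List[Pair]:
    dataset: List[Pair] = []
    for _ in range(epochs):
        # DAgger-style: add this epoch's pairs to everything collected so far
        dataset.extend(collect_pairs(questions, sample, score))
        train_dpo(dataset)  # train on the full aggregated dataset, not just the new pairs
    return dataset
```

Because every epoch trains on the growing union of pairs, each update sees the earlier pairs again, which is the stabilizing effect described above.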
Conclusion
In conclusion, the verifier model can play a crucial role in improving Large Language Models (LLMs) by providing feedback on the correctness of each step in the reasoning process. Although my Llama 3 8B verifier is not yet well trained, it shows potential for improving the TinyLlama 1.1B generator. In principle, guiding the generator’s training with the verifier’s feedback can make the learning process more efficient and the results more accurate.
However, in this case, we are only implementing an extremely simplified version of Q-learning with the verifier model. The model only provides the correctness of each step given the action and state, but it does not suggest the best action to choose. To fully leverage the potential of Q-learning, we need a more complex model that can provide the best action to take given the state.
In summary, while the current verifier implementation shows promise, there is still room for improvement. By developing a more capable model that can suggest the best actions and by using temporal-difference learning, we could further improve the performance of LLMs on math problems.