what is o1; why is it a big deal?
September 17, 2024.
I spent the last few days trying to reverse engineer the process behind o1, and I'd like to share what I think o1 is and why it's important for the future of LLMs. I feel pretty confident saying that the people who don't understand o1 are chalking this model up to glorified CoT, while the people I respect recognize it as a paradigm shift. It all goes back to what Karpathy talks about in this lengthy tweet:
"RLHF is that it is just barely RL, in a way that I think is not too widely appreciated. RL is powerful. RLHF is not. AlphaGo was trained with actual RL. The computer played games of Go and trained on rollouts that maximized the reward function (winning the game), eventually surpassing the best human players at Go.
What would it look like to train AlphaGo with RLHF? Well first, you'd give human labelers two board states from Go, and ask them which one they like better. Then you'd collect say 100,000 comparisons like this, and you'd train a "Reward Model" (RM) neural network to imitate this human "vibe check" of the board state. You'd train it to agree with the human judgement on average. Once we have a Reward Model vibe check, you run RL with respect to it, learning to play the moves that lead to good vibes. Clearly, this would not have led anywhere too interesting in Go. There are two fundamental, separate reasons for this:
- The vibes could be misleading - this is not the actual reward (winning the game). This is a crappy proxy objective.
- You'd find that your RL optimization goes off rails as it quickly discovers board states that are adversarial examples to the Reward Model. Remember the RM is a massive neural net with billions of parameters imitating the vibe. There are board states are "out of distribution" to its training data, which are not actually good states, yet by chance they get a very high reward from the RM.
you can't even run RLHF for too long because your model quickly learns to respond in ways that game the reward model."
The key here is that RLHF trains on a proxy objective (human preference) rather than the "actual" objective of solving problems. Real RL is the reason AlphaGo was able to become superhuman, and that's what we want. So how do we combine real RL with LLMs such that we can propel models into superhuman reasoning?
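To make the proxy problem concrete, here is a toy sketch of my own (not from Karpathy's tweet). Everything in it is a made-up stand-in: `true_reward` plays the role of "winning the game", a least-squares line fitted on a narrow range of states plays the role of the reward model, and a greedy hill climber plays the role of the RL optimizer.

```python
# Toy illustration only: a proxy reward fitted on a narrow slice of data
# can be "gamed" by an optimizer that wanders outside that slice.

def true_reward(x: float) -> float:
    """The real objective (hypothetical): the best possible state is x = 2."""
    return -(x - 2.0) ** 2

def fit_linear(xs, ys):
    """Ordinary least-squares line; stands in for a learned reward model."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var = sum((x - mean_x) ** 2 for x in xs)
    slope = cov / var
    return slope, mean_y - slope * mean_x

# The "labelers" only ever saw states in [0, 1], where a larger x looks better.
train_xs = [i / 10 for i in range(11)]
slope, intercept = fit_linear(train_xs, [true_reward(x) for x in train_xs])

def proxy_reward(x: float) -> float:
    """The fitted 'vibe check': accurate near the training data, wrong far away."""
    return slope * x + intercept

def hill_climb(reward, x: float, steps: int, step_size: float = 0.5) -> float:
    """Greedy local search: move to whichever neighbour scores highest."""
    for _ in range(steps):
        x = max((x - step_size, x, x + step_size), key=reward)
    return x

x_proxy = hill_climb(proxy_reward, 0.5, steps=20)
x_true = hill_climb(true_reward, 0.5, steps=20)
print(f"optimizing the proxy: x={x_proxy:.1f}, true reward={true_reward(x_proxy):.2f}")
print(f"optimizing the truth: x={x_true:.1f}, true reward={true_reward(x_true):.2f}")
```

The proxy agrees with the true reward on the data it was fitted to, but the optimizer keeps walking into regions the proxy never saw, and the true reward collapses. This is also why you can't run the proxy optimization for too long.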
o1 is by no means a solution to this question, but it's the first production-grade model trained with real RL. In OpenAI's words: "Our large-scale reinforcement learning algorithm teaches the model how to think productively using its chain of thought in a highly data-efficient training process." This is a direction of immense expected reward (I also suspect there is a much stronger version of o1 behind the scenes that is simply too expensive to serve).
Inference
An effect of o1's thinking is evident when we look at the distribution of compute over the model's lifecycle. Traditionally, large language models (LLMs) front-load compute into pre-training. In contrast, o1 pushes much of this compute into inference time, allowing the model to search for answers in real-time, leveraging its reasoning capabilities on the fly.
This shift makes sense: many parameters in large models are used for memorization, but reasoning doesn’t necessarily require a lot of parameters. With o1, pre-training is no longer the bottleneck for reasoning tasks. Instead, compute is allocated dynamically during inference, depending on the complexity of the problem. In traditional models, the same amount of compute is used regardless of problem difficulty, which is suboptimal. In theory, harder problems should demand more compute, and o1 aims to address this via search.
Jim Fan discusses the above in a great Twitter thread. He's also got a nice visualization of the compute distribution, which I've replicated (with the post-training stage adjusted, since it was too large):
We see that o1 has a significant portion of its compute dedicated to inference time. We don't have to speculate about this part: OAI has rate limited o1 harder than any other model, and the output tokens cost about 6x that of 4o.
Please note, however, that this shift has been advocated in numerous papers (and implemented) before o1, and if this were all o1 was doing I wouldn't be writing this post. For example, in tackling the ARC Challenge, LLMs have been hooked up to Python interpreters to perform discrete program searches. In a way, this can be seen as an inference-time Monte Carlo Tree Search (MCTS), where the model actively searches for a solution.
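For intuition, a discrete program search can be sketched in a few lines. This is a deliberately simplified brute-force enumeration rather than MCTS, and it leaves out the LLM entirely (in the ARC setups the LLM proposes candidate programs instead of enumerating them). The tiny DSL in `PRIMITIVES` and the example task are hypothetical.

```python
from itertools import product

# Hypothetical primitive operations on a list of ints (a tiny DSL).
PRIMITIVES = {
    "reverse": lambda xs: xs[::-1],
    "sort": lambda xs: sorted(xs),
    "double": lambda xs: [2 * x for x in xs],
    "drop_first": lambda xs: xs[1:],
}

def run_program(program, xs):
    """Apply each named primitive in order."""
    for name in program:
        xs = PRIMITIVES[name](xs)
    return xs

def search(examples, max_depth=3):
    """Enumerate programs shortest-first; return the first that fits every example."""
    for depth in range(1, max_depth + 1):
        for program in product(PRIMITIVES, repeat=depth):
            if all(run_program(program, list(inp)) == out for inp, out in examples):
                return program
    return None  # nothing within the depth budget explains the examples

# Input/output pairs that define the task; the interpreter is the verifier.
examples = [([3, 1, 2], [4, 6]), ([5, 4, 9, 1], [8, 10, 18])]
found = search(examples)
print(found)
```

The point is that running candidate programs gives a hard pass/fail signal, so more inference-time compute (a larger `max_depth`, more candidates) translates directly into a better chance of finding a solution.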
However, one of the main challenges in reasoning settings is understanding when to stop searching. In games, we have a clear, discrete reward function (e.g., win or lose), but when dealing with arbitrary problems, how does the model know when it has found the right answer? This is a key question for future developments in search-based inference.
Post-Training: Reinforcement Learning Meets LLMs
While the shift in inference compute is interesting, the real innovation in o1 is happening in the post-training stage. This is where things get particularly exciting. It seems likely that o1 introduces a new step in LLM training—one that combines large-scale next-token prediction with reinforcement learning (RL), MCTS, and (possibly) self-play.
"Our large-scale reinforcement learning algorithm teaches the model how to think productively using its chain of thought in a highly data-efficient training process. We have found that the performance of o1 consistently improves with more reinforcement learning (train-time compute) and with more time spent thinking (test-time compute). The constraints on scaling this approach differ substantially from those of LLM pretraining, and we are continuing to investigate them."
This would involve dedicating an entire training stage to generating MCTS-like reasoning traces. These traces would then be scored and used to compute gradient updates, reinforcing the model to reason more effectively. In essence, the model would be learning not just from static datasets but from its own reasoning processes, much like AlphaZero's self-play training loop. How exactly they've managed this, or to what extent it works, we don't know, but I think these papers hold some part of the solution:
Scaling LLM Test-Time Compute Optimally can be More Effective than Scaling Model Parameters
A paper that investigates how to scale test-time compute, focusing on two main approaches:
- Search at inference: They compared different search algorithms against Process Reward Models (PRMs):
  - Best-of-N: Sampling N independent outputs and selecting the best using a verifier.
  - Beam search: Maintaining a fixed number of candidate solutions, expanding and pruning based on PRM scores at each step.
  - Lookahead search: An extension of beam search that uses k-step rollouts to improve step evaluation accuracy. Similar to MCTS.
- Refining the proposal distribution: They fine-tuned models to iteratively revise their answers, analyzing:
  - Sequential revisions: The model generates a sequence of revisions, each conditioned on previous attempts.
  - Parallel sampling: Generating multiple independent solutions simultaneously.
  - Combinations of sequential and parallel approaches with varying ratios.
They found that the effectiveness of these methods varies depending on problem difficulty and compute budget. Easier questions often benefited more from sequential revisions, while harder questions required a balance between sequential and parallel computation. Beam search showed advantages over Best-of-N at lower compute budgets but diminishing returns at higher budgets.
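Here is a minimal sketch of the difference between Best-of-N and beam search. It is not the paper's setup: the "problem" (pick four steps that sum to a target), the random proposal function, and `prm_score` are all toy stand-ins I made up so the two search loops can run without a model.

```python
import random

# Hypothetical toy task: choose NUM_STEPS steps whose sum hits TARGET.
TARGET, NUM_STEPS, STEP_CHOICES = 10, 4, (1, 2, 3, 4)

def prm_score(partial):
    """Toy stand-in for a PRM: is the partial sum on pace to reach TARGET?"""
    on_pace = TARGET * len(partial) / NUM_STEPS
    return -abs(sum(partial) - on_pace)

def sample_solution(rng):
    """Toy stand-in for the LLM sampling one complete chain of steps."""
    return [rng.choice(STEP_CHOICES) for _ in range(NUM_STEPS)]

def best_of_n(rng, n):
    """Sample n complete solutions independently, keep the best-scoring one."""
    return max((sample_solution(rng) for _ in range(n)), key=prm_score)

def beam_search(rng, beam_width, expansions):
    """Grow solutions step by step, pruning to the top beams after every step."""
    beams = [[]]
    for _ in range(NUM_STEPS):
        candidates = [
            beam + [rng.choice(STEP_CHOICES)]
            for beam in beams
            for _ in range(expansions)
        ]
        candidates.sort(key=prm_score, reverse=True)
        beams = candidates[:beam_width]
    return beams[0]

rng = random.Random(0)
bon = best_of_n(rng, n=8)
beam = beam_search(rng, beam_width=2, expansions=4)
print("best-of-N:", bon, prm_score(bon))
print("beam search:", beam, prm_score(beam))
```

Best-of-N only consults the verifier once per finished solution, while beam search spends its verifier calls on intermediate steps and throws away bad prefixes early. Lookahead search would additionally roll each candidate forward a few steps before scoring it.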
Given that the efficacy of a given approach correlates heavily with the difficulty of the problem from the perspective of the base LLM's capabilities, the authors introduce the notion of "compute-optimal" scaling, where test-time compute is adaptively allocated based on estimated question difficulty. By applying such a compute-optimal scaling strategy, they improve the efficiency of test-time compute scaling by a factor of 2-4x compared to baseline methods.
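As a rough sketch of what "compute-optimal" allocation could look like in code: estimate difficulty cheaply, then pick how to split a fixed generation budget between sequential revisions and parallel chains. The difficulty estimator, the `EASY_THRESHOLD` value, and the splitting rule are my own assumptions for illustration; the paper selects the best strategy per difficulty bin empirically.

```python
import math
from dataclasses import dataclass

EASY_THRESHOLD = 0.3  # hypothetical cut-off, not a value from the paper

@dataclass(frozen=True)
class Plan:
    sequential_revisions: int  # revisions per chain
    parallel_chains: int       # independent chains sampled side by side

def estimate_difficulty(verifier_scores):
    """Difficulty in [0, 1]: one minus the mean verifier score of a few cheap samples."""
    return 1.0 - sum(verifier_scores) / len(verifier_scores)

def plan_compute(difficulty, budget):
    """Split `budget` generations so that revisions * chains == budget."""
    if difficulty < EASY_THRESHOLD:
        # Easy question: the first attempt is probably close, so just keep revising it.
        return Plan(sequential_revisions=budget, parallel_chains=1)
    # Harder question: balance exploration (chains) against refinement (revisions).
    chains = max(d for d in range(1, math.isqrt(budget) + 1) if budget % d == 0)
    return Plan(sequential_revisions=budget // chains, parallel_chains=chains)

print(plan_compute(estimate_difficulty([0.9, 0.8, 1.0]), budget=16))
print(plan_compute(estimate_difficulty([0.2, 0.1, 0.3]), budget=16))
```

The same idea extends to choosing between search algorithms (Best-of-N versus beam search) per difficulty bin instead of only choosing the sequential/parallel ratio.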
Importance for o1. RLHF assigns a single reward to the whole trajectory, meaning the model doesn't learn where it went wrong along the way. Process Reward Models (PRMs) are key to combining RL and LLMs because they provide per-step rewards. Combine those with an outcome-based reward model or heuristic that tells the system it got the answer right (and probably a length penalty, so it doesn't generate non-answers forever to avoid a negative return), and the system has a reward for each reasoning step that assigns credit towards the final answer.
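A minimal sketch of that credit assignment, under my own assumptions: per-step PRM scores, a bonus on the last step if the final answer checks out, and a small per-step length penalty, rolled up into a return for each reasoning step. The function name and the default constants are hypothetical, and nothing here is known to match what OpenAI does.

```python
def step_returns(step_scores, answer_correct,
                 outcome_bonus=1.0, length_penalty=0.01, gamma=1.0):
    """Return-to-go for each reasoning step.

    step_scores:    one PRM score per reasoning step (must be non-empty)
    answer_correct: outcome signal from a verifier or answer check
    """
    if not step_scores:
        raise ValueError("need at least one reasoning step")
    # Each step pays a small cost, so rambling forever is never free.
    rewards = [score - length_penalty for score in step_scores]
    # The outcome reward lands on the final step only.
    if answer_correct:
        rewards[-1] += outcome_bonus
    # Accumulate backwards so each step is credited with what followed it.
    returns, running = [], 0.0
    for reward in reversed(rewards):
        running = reward + gamma * running
        returns.append(running)
    return returns[::-1]

# The second step was judged a misstep by the PRM, the answer was still right.
print(step_returns([0.5, -0.5, 0.5], answer_correct=True, length_penalty=0.0))
```

With a trajectory-level reward, all three steps would receive the same signal. Here the bad middle step lowers the return of everything leading into it, which is exactly the "where did it go wrong" information that a whole-trajectory reward throws away.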
ReFT: Reasoning with REinforced Fine-Tuning
ReFT is a two-stage approach to fine-tune large language models for math problem-solving. The first stage is a warm-up phase using standard supervised fine-tuning (SFT) on a dataset of questions paired with Chain-of-Thought (CoT) annotations. This equips the model with basic problem-solving skills.
The second stage employs online reinforcement learning, specifically the Proximal Policy Optimization (PPO) algorithm. In this phase, the model learns by repeatedly sampling responses to questions, evaluating the correctness of the generated answers, and updating its parameters. The reward function is based on the correctness of the final answer, with partial rewards possible for numeric answers. Importantly, ReFT uses the same training questions as SFT but allows the model to explore multiple reasoning paths per question. To prevent the policy from diverging too far from the initial model, a KL divergence term is added to the reward function. This approach enables the model to learn from both correct and incorrect attempts, potentially improving its generalization ability without requiring additional training data or annotations.
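The reward shaping described above can be sketched as follows. This is my simplified reading, not the paper's code: the partial-credit value, the `beta` coefficient, and the use of a per-token log-ratio as the KL estimate are assumptions chosen for illustration.

```python
from typing import List, Optional

def terminal_reward(predicted: Optional[float], gold: float,
                    partial: float = 0.1) -> float:
    """Reward for the final answer extracted from a sampled chain of thought."""
    if predicted is None:              # no numeric answer could be extracted
        return 0.0
    if abs(predicted - gold) < 1e-6:   # correct answer
        return 1.0
    return partial                     # numeric but wrong: small partial credit

def shaped_rewards(policy_logprobs: List[float], ref_logprobs: List[float],
                   predicted: Optional[float], gold: float,
                   beta: float = 0.05) -> List[float]:
    """Per-token rewards: KL penalty everywhere, answer reward on the last token.

    policy_logprobs / ref_logprobs hold the log-probability of each sampled
    token under the current policy and under the frozen SFT (warm-up) model.
    """
    rewards = [-beta * (lp - ref) for lp, ref in zip(policy_logprobs, ref_logprobs)]
    rewards[-1] += terminal_reward(predicted, gold)
    return rewards

# The policy drifted from the reference on the second token but got the answer right.
print(shaped_rewards([-1.0, -2.0], [-1.0, -1.0], predicted=4.0, gold=4.0, beta=0.5))
```

These per-token rewards are what a PPO implementation would turn into advantages. Because the same question is sampled many times, wrong attempts also carry signal, which is how the approach learns from both correct and incorrect reasoning paths.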
The experiments demonstrated that ReFT consistently outperformed supervised fine-tuning (SFT) across multiple datasets and model architectures. Notably, on the GSM8K dataset using CodeLLAMA, ReFT achieved up to 12 percentage points improvement over SFT. This performance gain was achieved without using any additional training data, suggesting that ReFT's exploration of multiple reasoning paths per question leads to better generalization.
Quiet-STaR: Language Models Can Teach Themselves to Think Before Speaking
The Quiet-STaR paper builds on the Self-Taught Reasoner (STaR) framework, which aimed to bootstrap language models' reasoning abilities. In STaR, a model learns reasoning by generating rationales for few-shot question-answering tasks and reinforcing those that lead to correct answers. While effective, this method was limited to predefined, curated reasoning tasks. Quiet-STaR expands this idea by enabling the model to infer rationales from any text, generalizing reasoning to more diverse, unstructured data.
These rationales act as internal thoughts that help the model reason beyond token-by-token prediction, enabling the model to silently think through the implications of text.
The process involves three core steps:
- Thinking. Using a parallel sampling algorithm, Quiet-STaR generates rationales at each token in a sequence to improve future predictions.
- Talking. The post-rationale next-token logits are mixed with the base model logits, using a weight produced by a shallow MLP, to ease the distribution shift in the early stages of fine-tuning.
- Learning. Learning to generate better rationales using REINFORCE; optimizing the rationale generation parameters to increase the likelihood of rationales that make future text more probable.
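The talking and learning steps can be sketched in plain Python. This is a heavily simplified illustration with assumptions of my own: the mixing weight `w` is passed in as a number instead of being produced by an MLP, each "thought" is represented only by the next-token logits it leads to, and the reward looks at a single future token.

```python
import math

def mix_logits(base_logits, thought_logits, w):
    """Talking: interpolate between base and post-thought logits (w in [0, 1])."""
    return [w * t + (1.0 - w) * b for b, t in zip(base_logits, thought_logits)]

def log_softmax(logits):
    """Numerically stable log-probabilities from raw logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

def reinforce_rewards(base_logits, thought_logits_per_sample, w, true_token):
    """Learning: reward each sampled thought by how much it helped predict the
    true next token, relative to the average over all sampled thoughts."""
    logps = [
        log_softmax(mix_logits(base_logits, thought_logits, w))[true_token]
        for thought_logits in thought_logits_per_sample
    ]
    baseline = sum(logps) / len(logps)
    return [lp - baseline for lp in logps]

# Two sampled thoughts: the first pushes towards the true token (index 0),
# the second pushes towards a wrong token (index 1).
rewards = reinforce_rewards(
    base_logits=[0.0, 0.0, 0.0],
    thought_logits_per_sample=[[2.0, 0.0, 0.0], [0.0, 2.0, 0.0]],
    w=0.5,
    true_token=0,
)
print(rewards)
```

In a REINFORCE update, each reward would multiply the log-likelihood gradient of the corresponding thought's tokens, so helpful thoughts become more likely and unhelpful ones less likely. Early in training the mixing weight can stay near zero, which keeps the model's predictions close to the base model while the thoughts are still noise.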
Now for why OAI succeeded: I think it just comes back to scale. They've realized this by now. They don't get bogged down in the details; they build shit that scales. While these papers all do a great job at providing possible ways to reinforce reasoning, OAI manages to abuse scale. I suspect their RL algo is fairly simple in nature.
Anyway, what I want you to see is that o1 (and future reasoning agents) probably has a compute distribution closer to something like this
with an immense amount of resources being pulled into this real RL stage, shifting compute away from the now-overparameterized pre-training stage.
We see signs of this in o1-mini, which still manages to perform well on reasoning tasks despite being a smaller model. This is likely because a significant portion of the compute in o1 is dedicated to this synthetic data generation and self-improvement step. Models like this will be combined into a larger system of models where requests are routed based on their needs. Obviously OAI doesn't want o1 spending 200 tokens thinking about questions like "How many countries are there in Europe?".