How OpenAI's o1 changes the LLM training picture - Part 2
Read part 1: How OpenAI's o1 changes the LLM training picture - Part 1
In part 1, we covered essential LLM concepts and training methods, and built intuition for why problem solving likely requires capabilities beyond those methods. In part 2, we'll speculate on the path to o1, summarizing existing techniques and drawing on expert commentary. Let’s go!
Let’s “Go”
Before we get to o1, let’s look at a modern RL example that can build our intuition further: AlphaGo. If AlphaGo had been trained like the recent generation of LLMs, training it would have looked like showing it many games from humans and saying “From each board position, determine the move that a human would have made using these as examples.” If you’ve ever had to learn to play a tough strategy game like Go or chess, you probably know that merely watching a bunch of experts play will not make you an expert yourself (though it certainly helps build intuition).
In its early stages, AlphaGo was indeed trained to mimic human moves in historical games, which bootstrapped the model's base skill. However, the main training involved components familiar from the previous post:
- “value network”: a neural network that predicts the game winner from a given board state. Note that this is different from the reward function: it estimates, roughly, the expected eventual reward from the current state rather than an immediate payoff.
- reward function: generate reward for winning, penalty for losing
- “policy network”: a separate network used to aid move selection
- Game tree search: a class of algorithms for exploring the “game tree” by “playing out” (aka “rollout”) various versions of how the game could go, given the current board. In the case of AlphaGo, the tree search was Monte Carlo Tree Search (MCTS). MCTS doesn’t play out every version of how the game could go but rather leverages some auxiliary guiding heuristics to identify the most important moves to examine. For AlphaGo these guiding heuristics included the value and policy networks.
Armed with these and the human bootstrapped proficiency, AlphaGo was trained by letting it play against itself repeatedly to improve the learned value and policy networks. During actual games, MCTS is used with the learned networks to guide its moves.
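To make the interplay concrete, here is a heavily simplified sketch of how learned policy and value networks can steer MCTS node selection. The networks below are random-output stand-ins and the PUCT-style scoring is illustrative only; this is not AlphaGo's actual implementation.

```python
import math
import random

# Illustrative stand-ins for the learned networks; the real system uses deep neural nets.
def policy_net(state, legal_moves):
    """Return a prior probability for each legal move (the 'policy')."""
    raw = [random.random() for _ in legal_moves]
    total = sum(raw)
    return {move: p / total for move, p in zip(legal_moves, raw)}

def value_net(state):
    """Estimate the expected eventual reward (e.g. win probability) from this state."""
    return random.uniform(-1, 1)

class Node:
    def __init__(self, state, prior):
        self.state = state
        self.prior = prior        # prior from the policy network
        self.visits = 0
        self.value_sum = 0.0      # accumulated backed-up value estimates
        self.children = {}        # move -> Node

    def q(self):
        """Average backed-up value: the search's running estimate of this node's worth."""
        return self.value_sum / self.visits if self.visits else 0.0

def select_child(node, c_puct=1.5):
    """PUCT-style selection: trade off the running value estimate (exploitation)
    against the policy prior and visit counts (exploration)."""
    total_visits = sum(child.visits for child in node.children.values())
    def score(child):
        exploration = c_puct * child.prior * math.sqrt(total_visits + 1) / (1 + child.visits)
        return child.q() + exploration
    return max(node.children.values(), key=score)

def expand_and_evaluate(node, legal_moves):
    """At a leaf: add children weighted by the policy network, and return the value
    network's estimate to back up the tree instead of playing the game to the end."""
    for move, prior in policy_net(node.state, legal_moves).items():
        node.children[move] = Node(state=(node.state, move), prior=prior)
    return value_net(node.state)
```

Repeated many times per move, this loop concentrates visits on strong lines of play, and the self-play results then provide the training signal for improving the value and policy networks.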
The result was the first program to consistently defeat human professionals at Go. Though its play style was similar to humans, it differed in important ways, even making "creative" moves that surprised experts but were judged important to victory.
Play it out?
So could a similar approach be used to train a problem solving language model? Let’s look at a naive version of what that would be, and then examine the challenges. Here’s a candidate “recipe”:
- Bootstrap basic ability from human text examples, as AlphaGo mimicked human play
- Treat legal tokens as "moves", use LLM as a "policy network" to sample promising next tokens
- Use the LLM as the “policy network” which helps sample promising/important next tokens to explore
- Use an auxiliary "value network" to estimate the likelihood of reaching a correct answer (see Ⓠ footnote)
- Use a tree search algorithm, the policy network, and the value network to search for chain-of-thought rollouts that lead to a correct answer.
- Use another LLM (likely armed with correct answers) as a reward model to provide rewards/penalties
- Iteratively learn better value and policy networks as you go.
Great! We’re done, let’s collect our multi-billion-dollar valuation now! Except… there are problems with this. At each step in the tree search, the number of “branches” from that point depends on the number of legal moves available. In chess, this branching factor is estimated to average around 31, while in Go it’s around 250 (one of the reasons Go took much longer to “solve” with AI than chess). For a modern LLM, the number of tokens in its “vocabulary” is much, much larger (around 128k for Llama 3). Even though MCTS doesn’t explore every move at every level of the tree, the branching factor still determines the scale of the search space it must navigate, and an LLM CoT rollout done this way would make that tree much, much wider.
We talked about tree breadth — what about tree depth? Using the “decode the ciphertext” example from the o1 release post, the chain-of-thought given is about 5.2k tokens using the GPT-4o tokenizer. The final search tree for an LLM using this approach would therefore be at least 5.2k tokens deep. A typical Go game between professionals is about 150 moves from beginning to end. So this search tree is also much deeper. And while the LLM doesn't need to explore the full breadth of the search tree, it would need to reach the required depth to complete a valid chain-of-thought.
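A quick back-of-the-envelope comparison using these very rough numbers (and before any pruning by the search algorithm) makes the gap concrete:

```python
import math

# Very rough upper bounds on full tree size: branching_factor ** depth.
# Log base 10 keeps the numbers printable.
go_game_tree = 150 * math.log10(250)         # ~250 legal moves per turn, ~150-move games
token_cot_tree = 5200 * math.log10(128_000)  # ~128k-token vocabulary, ~5.2k-token chain-of-thought

print(f"Full Go game tree:     ~10^{go_game_tree:.0f} leaves")
print(f"Token-level CoT tree:  ~10^{token_cot_tree:.0f} leaves")
```

Even granting that MCTS visits only a vanishingly small fraction of its tree, the jump from roughly 10^360 to well over 10^26,000 possible rollouts suggests why a token-by-token search is a non-starter.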
Bigger Ideas
If the above is not tractable, could it still be on the right track? Possibly. Let’s look at how by considering a pure prompting + sampling approach to enhancing an LLM’s problem-solving abilities: “Tree-of-Thoughts” (ToT). In this approach, the LLM generates suggestions for the next problem-solving step, rather than treating each valid token as a possible “next node” in the search tree. For example, if the problem were the Game of 24, the LLM (acting as a policy) might be asked to look at a partial solution and propose some possible next steps. The LLM (acting as a value model) might be asked whether a given state is likely to lead to a solution. Meanwhile, a tree search algorithm (in the original ToT paper, depth-first or breadth-first search) guides the process of evaluating proposals and values along paths through the tree.
In this case, the nodes in our tree aren’t single tokens: they’re “thoughts” that represent a larger step in the problem-solving context. This ToT approach doesn’t require ANY additional learning: the LLM is used as-is and plays “double duty” as both value and policy models. It does, however, still require a tree search to guide the process. The lack of any additional learning is also a disadvantage: the model can’t be trained in this context to improve its performance. Another major drawback is that the approach requires specialized prompts and parsing logic for different problems, preventing it from serving as a general-purpose problem solver without problem-specific support.
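As a concrete (and heavily simplified) illustration of the idea, here is roughly what a breadth-first ToT loop looks like in code. The `call_llm` helper, prompts, and scoring scheme are hypothetical placeholders rather than the paper's actual implementation, and as noted above the prompts would need to be tailored to each problem.

```python
def call_llm(prompt: str) -> str:
    """Stand-in for whatever LLM API you use (chat completion, local model, etc.)."""
    raise NotImplementedError("plug in your model here")

def llm_propose(problem, partial, k=3):
    """LLM as 'policy': propose k candidate next steps given the steps so far."""
    prompt = f"Problem: {problem}\nSteps so far: {partial}\nPropose {k} distinct next steps, one per line."
    return call_llm(prompt).splitlines()[:k]

def llm_evaluate(problem, partial):
    """LLM as 'value model': score how likely this partial solution is to reach an answer."""
    prompt = f"Problem: {problem}\nSteps so far: {partial}\nOn a 0-1 scale, how likely is this to lead to a solution? Reply with a number."
    return float(call_llm(prompt))

def tree_of_thoughts_bfs(problem, depth=4, beam_width=5, k=3):
    """Breadth-first search over 'thoughts', pruning to the most promising partial solutions."""
    frontier = [[]]  # each entry is a list of thought-steps taken so far
    for _ in range(depth):
        candidates = [partial + [step]
                      for partial in frontier
                      for step in llm_propose(problem, partial, k)]
        candidates.sort(key=lambda p: llm_evaluate(problem, p), reverse=True)
        frontier = candidates[:beam_width]  # keep only the best few branches
    return frontier[0]  # the most promising chain of thoughts found
```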
Let’s “O” (1)
While we don’t know exactly how o1 was trained, we now have enough in hand to suggest a promising possibility:
- Start with a base LLM (e.g. GPT-4o/GPT-4o-mini), possibly primed with problem solving transcripts, to bootstrap rudimentary “problem solving through mimicry” ability.
- Use that LLM as a policy network to generate short “reasoning steps” in something like a “tree-of-thoughts”
- Use a model to learn a value function (or something playing a similar role, such as a Q function). This might be an entirely separate model or may be the same LLM invoked differently
- Use a tree search algorithm (ex: MCTS, A*) to explore the space of possible chain-of-thoughts
- Use problem/answer pairs (WITHOUT specifying intermediate steps) to have the “LLM as policy network” and value function improve themselves through RL training, aiming to solve the pairs.
A big picture illustration of what problem solving may look like at train time (possibly also test time). The “chain” within this tree that’s ultimately selected would be obtained by reading along the heavy black branches.
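Purely to tie these bullets together, here is the shape such a loop could take, written as Python-flavored pseudocode. This is speculation, not o1's method: every helper here (`search_reasoning_tree`, `grade`, `update_policy`, `update_value`) is a hypothetical placeholder, and the actual reward scheme and RL update are precisely among the open questions listed below.

```python
def train_reasoner(policy_llm, value_model, problems, iterations):
    """Speculative outer loop: improve a 'reasoning' policy via tree search plus RL,
    using only problem/answer pairs (no human-written intermediate steps)."""
    for _ in range(iterations):
        trajectories = []
        for problem, answer in problems:
            # Tree search over "thoughts": the policy LLM proposes reasoning steps,
            # the value model scores partial chains, and a search algorithm
            # (MCTS, A*, ...) decides which branches to expand.
            tree = search_reasoning_tree(problem, policy_llm, value_model)

            # Reward depends only on the final answer, possibly judged by another
            # LLM that has access to the reference answer.
            for chain in tree.completed_chains():
                reward = 1.0 if grade(chain.final_answer, answer) else -1.0
                trajectories.append((problem, chain, reward))

        # RL update: push the policy toward rewarded chains, and fit the value model
        # to predict eventual reward from partial chains. The exact losses are unknown.
        update_policy(policy_llm, trajectories)
        update_value(value_model, trajectories)
```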
Can we have our multi-billion-dollar valuation now? Not quite. There are still A LOT of unanswered questions here, even assuming this is on the right track:
- What is required to bootstrap this procedure effectively, so the starting point policy LLM is competent enough to start solving meaningful problems?
- What’s the right “thought size” to treat as a node in the tree? Does it differ much from problem-to-problem? How is that managed?
- How do you keep the branching factor under control, especially during early training stages?
- Do you need any additional special procedures to keep the model from going “out-of-distribution” and collapsing?
- Do you learn a value function or something else (ex: Q-function)?
- Do you leverage rewards for valid intermediate steps (process rewards), or only for reaching the final answer (outcome rewards)?
- What tree search algorithm is most effective?
- How, specifically, do you construct the loss functions for the policy/value LLMs?
- How much compute does all this require?
Despite these open questions, we are not the first to speculate along these broad outlines.
End game
Despite all the uncertainty and unknowns, there are a few inferences we can draw:
- If "true" RL has been achieved, model improvement becomes more "compute constrained" than "data constrained": the system can continue improving through more self-play on the problems (AlphaGo’s successor was eventually able to beat humans without any human game data to bootstrap it, though the same is likely not possible for “reasoning models”). The release post indeed says training uses RL and is “highly data-efficient”
- Given that full RL approaches would allow the model to “explore” without constraining it to human examples, it is conceivable that it might develop problem solving strategies not yet used by humans.
- Training likely requires a ton of inference, since you need to let the model generate many partial “thought trees” per problem.
- Inference after training might also require generating “trees” of thoughts before discovering the final “chain” that leads to the correct answer (see 🧪 footnote)
- If thought trees are used at inference time, the search algorithm that guides generation can be parameterized to spend more time exploring the tree, giving some control over a cost/quality tradeoff (a hypothetical sketch follows this list)
- Regardless, the observations available so far show that the hidden chains themselves can be quite long, meaning more inference-time compute
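For instance, if a search does run at inference time, the cost/quality knob could be as simple as a rollout or wall-clock budget; a hypothetical sketch:

```python
import time

def answer_with_budget(problem, expand_once, budget_seconds=30.0):
    """Keep expanding the thought tree until the compute budget is spent, then return
    the best complete chain found so far. `expand_once` is a hypothetical callable that
    performs one round of search and returns the best chain found so far (or None)."""
    deadline = time.monotonic() + budget_seconds
    best_chain = None
    while time.monotonic() < deadline:
        best_chain = expand_once(problem) or best_chain
    return best_chain

# Doubling budget_seconds buys more exploration and (hopefully) better answers:
# answer = answer_with_budget(problem, expand_once=my_search_step, budget_seconds=120.0)
```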
Perhaps the most important takeaway from these speculations and corollaries is a plot that OpenAI released with their o1 announcement. Via their RL-based methods, they have been able to achieve a system where you can just continue spending on inference “indefinitely” to get improved results, which, as the SCoRe paper points out, is a nut that’s mostly uncracked elsewhere. As Dr. Jim Fan points out, this creates a “flywheel” where better inference leads to better training, which in turn leads to better inference…
The “more compute is all you need” camp may or may not win the day in the end, but o1 so far appears to be a solid point in their favor. There’s certainly still plenty of room left to run. At the very least, it’s clear that o1 represents something new.
Footnotes
Ⓠ: An early alias of o1/Strawberry was “Q*”. This led some to speculate that the model might leverage a learned Q-function, which plays a role similar to the learned value function in AlphaGo. The main difference is that the Q-function learns to predict the expected future reward for (current state, next move) pairs, whereas the value function predicts the expected future reward for the current state alone. For our high-level purposes here, we can ignore this distinction. Similarly, the Q* name suggested that the tree search algorithm in use might be A*, especially considering that A* search + a learned Q-function has already been referred to as Q*. Again, at this level, the distinction isn’t too important.
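For the curious, here is that distinction in standard RL notation (nothing here is specific to o1 or AlphaGo), for a policy π:

```latex
V^{\pi}(s)    = \mathbb{E}_{\pi}\left[ G_t \mid S_t = s \right]             % value of a state
Q^{\pi}(s, a) = \mathbb{E}_{\pi}\left[ G_t \mid S_t = s, A_t = a \right]    % value of a state-action pair
V^{\pi}(s)    = \mathbb{E}_{a \sim \pi(\cdot \mid s)}\left[ Q^{\pi}(s, a) \right]
```

where G_t denotes the (possibly discounted) sum of future rewards from time t.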
🧪: Following an AMA with OpenAI, it was reported that they stated: “o1 is not a ‘system’; it's a model trained to generate long chains of thought before returning a final answer”. One might read into this that no search algorithm is used in conjunction with the model after train time. This tweet also suggests they may not be using tree search at inference time. In any case, there is some ambiguity around this point. I am not aware of any statements from OpenAI about the use of tree search during training, though many comparable RL systems (AlphaGo among them) do rely on one.