How OpenAI's o1 changes the LLM training picture - Part 1

Josh Bauer·9/27/2024

If you're familiar with "Chain of Thought" prompting and know that RLHF means "Reinforcement Learning from Human Feedback", but are hazy on the details of modern language model training, you might wonder what sets OpenAI's new o1 models apart. Why does it excel on certain benchmarks? Can't we just ask existing models to explain their reasoning? Doesn't RLHF already involve Reinforcement Learning? If these questions resonate with you, you're in the right place!

This is part one of a two-part series. We'll cover essential background and discuss the shortcomings in existing LLM training that o1 aims to address. Part two will explore known and speculative details about o1's solutions.

Essential Background

Before we get into o1, it’s important to have a baseline understanding of how LLMs work and how they are trained.

LLM Basic Principles

LLMs generate outputs token by token, with tokens ranging from word fragments to short phrases. This process, called autoregression, works as follows:

  1. The model takes prefix text and produces a probability distribution for the next token.
  2. The invoking software samples the next token from this distribution.
  3. This token is appended to the input.
  4. The process repeats until the model emits a special "end" token.

Illustration of autoregression from a video by 3Blue1Brown
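For concreteness, the loop described above can be sketched in a few lines of Python. The `model` callable, `prompt_tokens`, and `end_token` below are hypothetical stand-ins: `model` is assumed to return a probability distribution over the vocabulary for the next token.

```python
import numpy as np

def generate(model, prompt_tokens, end_token, max_new_tokens=256, seed=0):
    rng = np.random.default_rng(seed)
    tokens = list(prompt_tokens)
    for _ in range(max_new_tokens):
        probs = model(tokens)                               # 1. distribution over the next token
        next_token = int(rng.choice(len(probs), p=probs))   # 2. sample from that distribution
        tokens.append(next_token)                           # 3. append it to the input
        if next_token == end_token:                         # 4. stop at the special "end" token
            break
    return tokens
```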

Pre-training/Supervised fine-tuning

The first two LLM training steps are:

  • Pre-training: The most compute-intensive phase, where the model learns from vast amounts of text, essentially resulting in a "Blurry JPEG of the Web".
  • SFT (Supervised Fine-Tuning): The model is trained on chat-like transcripts, mimicking desired final outputs.

In both of these cases, the model is being trained to reproduce the text you’ve fed it as closely as possible.
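Both phases use the same underlying objective: predict the next token of the training text. Here is a minimal sketch of that loss, assuming PyTorch and a hypothetical model that has already produced per-position logits over the vocabulary.

```python
import torch
import torch.nn.functional as F

def next_token_loss(logits: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor:
    # logits: (seq_len, vocab_size) model outputs; token_ids: (seq_len,) tokens of the training text.
    # The prediction at position t is scored against the actual token at position t+1.
    return F.cross_entropy(logits[:-1], token_ids[1:])
```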

RLHF/RLAIF/DPO: "Alignment"

Producing coherent text differs from generating text humans like. For example, "Can't let you do that, Dave" and "Just turn the handle counter-clockwise" are both coherent responses to "How do I open the pod bay doors?"

To guide the model towards preferred responses, various techniques are employed. Most involve training an auxiliary "reward model" on human (or AI) preference data. The main model is then trained to generate responses favored by this reward model.

DPO (Direct Preference Optimization) achieves similar results without a separate reward model, using mathematical techniques to directly optimize for preferred outputs. However, this explanation has important caveats:

  1. Current alignment approaches have limited exploration capacity. They mainly "nudge" the model's initial answers towards preferred responses in small steps. This constraint is necessary to prevent model collapse, but (as we’ll see) it limits the model to slight variations of its pre-alignment outputs rather than genuinely novel problem-solving.
  2. Reward models excel at capturing the general "vibe" of good answers. Their construction for alignment purposes isn't optimized for distinguishing between correct and incorrect responses, focusing more on overall quality and appropriateness.
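For concreteness, here is a sketch of the DPO loss mentioned above, assuming you already have per-response log-probabilities under both the policy being trained and a frozen reference model (the pre-alignment checkpoint). Variable names are illustrative, not from any particular library.

```python
import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.1):
    # Implicit rewards: how much the policy shifts log-probability toward each response
    # relative to the frozen reference model.
    chosen_reward = beta * (policy_chosen_logp - ref_chosen_logp)
    rejected_reward = beta * (policy_rejected_logp - ref_rejected_logp)
    # Logistic loss that pushes the preferred response above the dispreferred one.
    return -F.logsigmoid(chosen_reward - rejected_reward).mean()
```

Note how the reference model appears in both terms: the policy is only rewarded for shifting probability mass relative to where it started, which is part of why these methods tend to make small, incremental adjustments.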

Chain of Thought, Prompting, and Sampling Techniques

It turns out that simply asking LLMs to explain their reasoning before committing to an answer can improve the chances that they will produce a correct one. This is one of the simplest prompt-based techniques for getting the LLM to produce something more accurate, but many, many more elaborate prompting schemes have been used, including ones that generate many responses and sample from them. One specific example, “Self-Refine”, asks the model to critique and improve its response repeatedly before committing to a final answer.
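As a rough illustration, a Self-Refine-style loop can be sketched as below. This is not the paper's exact prompting; `llm` is a hypothetical prompt-to-completion function and the prompts are purely illustrative.

```python
def self_refine(llm, question, rounds=3):
    # Initial chain-of-thought style attempt.
    answer = llm(f"{question}\nExplain your reasoning step by step, then answer.")
    for _ in range(rounds):
        # Ask the model to critique its own answer...
        critique = llm(f"Question: {question}\nAnswer: {answer}\n"
                       "Critique this answer and point out any mistakes.")
        # ...then revise in light of that critique.
        answer = llm(f"Question: {question}\nAnswer: {answer}\n"
                     f"Critique: {critique}\nWrite an improved answer.")
    return answer
```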

Reinforcement Learning (RL)

Reinforcement Learning is a key domain of machine learning, focused on building models that solve problems in specific contexts like chess or Go. A core idea is breaking such a context down into the following pieces:

  1. An “agent” which is trying to accomplish some goal (defined by a “reward function”)
  2. A “reward function” that gives higher numbers to the agent when it is doing well at achieving the goals, and lower numbers when it is not.
  3. A “state” that represents the current situation the agent is in
  4. A set of “actions” available to the agent that it can take from any given state
  5. A “policy” for selecting an action given the state.

Let’s make this concrete. In chess, the "agent" is the AI playing the game, the "reward function" is the system that assigns points for results like capturing pieces or achieving checkmate. The "state" refers to the current board layout, and the "actions" are the legal moves available. The "policy" is the model or algorithm trained to analyze the board and select the next move.
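In code, the loop these pieces describe looks roughly like the sketch below. `env` and `policy` are hypothetical: `env.reset()` returns an initial state and `env.step(action)` returns the next state, a reward, and whether the episode is over.

```python
def run_episode(env, policy):
    state = env.reset()                          # the starting "state"
    total_reward, done = 0.0, False
    while not done:
        action = policy(state)                   # the "policy" chooses an "action"
        state, reward, done = env.step(action)   # the environment produces a new state
        total_reward += reward                   # the "reward function" scores progress
    return total_reward
```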

Reinforcement learning (RL) involves letting the model "figure out" solutions during training, rather than explicitly telling it what to do. The model explores action sequences ("rollouts") and uses feedback to learn which strategies work for future attempts.

Since RL models explore strategies independently, they aren't limited by human input or performance. This allows them to exceed human abilities, as seen in examples like AlphaGo's dominance in Go, which was achieved using RL.

How does this relate to “RL”HF & similar? While they borrow concepts from RL, like reward functions and RL-based loss functions, they lack the full action-sequence exploration needed for general problem solving or superhuman performance. I probably can’t say it better than Andrej Karpathy!

Andrej Karpathy tweet about RLHF

Why o1 is hard: The Problems with Problem Solving

If you're familiar with o1, you might wonder: "Is it just a model trained with SFT or an RLHF-style reward system, using chat transcripts with chain-of-thought prompting?" While the exact training process isn't known, I believe it's safe to say the answer is a clear "no".

The rest of this post will build intuition on why we can be confident in this, followed by real-world results from similar attempts. The next post will explore informed speculations on what might be happening during o1's training.

A Motivating Example

To build some intuition, let’s use a specific problem:

What's the next term in this sequence? 1, 8, 72, 46, 521, 612, 343, 215, …

I promise you this takes nothing harder than middle-school math to solve. I showed the solution to my very non-technical parents and they were able to follow it easily (though neither solved the puzzle on their own). Take a few minutes to actually try it before reading on. You don’t need to actually get the answer, but try to notice what your process looks like as you attempt it.

Now, let’s look at what ChatGPT-4o says when I give it this, and ask it to explain before answering:

ChatGPT-4o’s answer

To find the next term in the sequence, let's analyze the pattern step by step.

  1. First, consider the differences between consecutive terms:
    • 8 − 1 = 7
    • 72 − 8 = 64
    • 46 − 72 = −26
    • 521 − 46 = 475
    • 612 − 521 = 91
    • 343 − 612 = −269
    • 215 − 343 = −128
  2. Now, let’s look at how these differences change:
    • 7 → 64 (difference of +57)
    • 64 → −26 (difference of -90)
    • −26 → 475 (difference of +501)
    • 475 → 91 (difference of -384)
    • 91 → −269 (difference of -360)
    • −269 → −128 (difference of +141)
  3. Observing this second set of differences:
    • +57, −90, +501, −384, −360, +141

There seems to be an alternating pattern between large positive and negative differences, but they do not follow a simple arithmetic or geometric progression. However, we can attempt to follow the alternating trend of large positive and negative changes.

  1. The last difference was −128, so based on the pattern, we expect the next change to be a large positive number, similar to previous jumps (like +501 or +141).

Without a clear simple pattern emerging, but noticing a general tendency for alternating larger positive and negative shifts, we can make an informed guess that the next number will increase by a large amount.

Thus, the next term in the sequence should increase significantly, potentially by around +400 to +500. Let's estimate the next term as approximately 215 + 400 = 615.

Answer: 615 (with reasoning based on alternating positive and negative shifts of large magnitudes).

This, as should be obvious, is incorrect. Now, here’s what happened when I asked o1-preview this question:

o1-preview

First the model spent about a minute “thinking,” during which time the ChatGPT app summarized various approaches it tried that didn’t work (e.g., “Looking at differences”, “Looking for patterns in spelling”, …). It then output the following correct answer:

The sequence you've provided is:

1, 8, 72, 46, 521, 612, 343, 215, ...

At first glance, this sequence doesn't appear to follow a simple arithmetic or geometric pattern. However, by rearranging the digits of each number, a pattern emerges:

  • 1 remains 1.
  • 8 remains 8.
  • 72 rearranged is 27, which is 3^3.
  • 46 rearranged is 64, which is 4^3.
  • 521 rearranged is 125, which is 5^3.
  • 612 rearranged is 216, which is 6^3.
  • 343 remains 343, which is 7^3.
  • 215 rearranged is 512, which is 8^3.

These rearranged numbers correspond to the cubes of the numbers from 1 to 8. Following this pattern, the next cube is 9^3 = 729. Therefore, the next term in the sequence should be a rearrangement of 729.

Possible rearrangements of 729 are:

  • 729
  • 792
  • 927
  • 972
  • 279
  • 297

Any of these rearranged numbers would logically fit as the next term. However, since the sequence seems to prefer certain rearrangements (e.g., 72 for 27, 46 for 64), the most likely next term is 927.

Answer: 927

This is correct. See the 💡 footnote about how I obtained this problem and info on its presence online.
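If you want to verify the pattern yourself, a one-liner reproduces the sequence (and the next term, 927) by reversing the digits of each cube, matching the "reversed cube" description in the footnote.

```python
# Each term is n cubed with its digits reversed.
terms = [int(str(n ** 3)[::-1]) for n in range(1, 10)]
print(terms)  # [1, 8, 72, 46, 521, 612, 343, 215, 927]
```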

What does “problem-solving” look like?

Unless you are Ramanujan reincarnated, you probably didn’t immediately spot that these were cubed numbers reversed. You likely tried some standard techniques, like checking successive differences, as ChatGPT-4o did. When that didn't work, you moved on to other approaches. If you're persistent or familiar with puzzles like this, you might have thought to reverse the digits. If you’re a certain brand of nerd, you’d recognize the cubes. At that point you’d be able to explain the pattern and follow it to produce the next term.

This reveals some things that seem to be fundamentally important for problem solving, when you can’t just remember an answer you’ve seen before:

  1. Having a toolbox of background knowledge and techniques to pull from
  2. Being able to try a variety of things
  3. Being able to recognize when something isn’t working
  4. Being able to subsequently switch tactics to another promising direction
  5. Being able to recognize when something is working so you can continue with that
  6. Being able to leverage that recognition to continue making progress
  7. Being able to recognize when you are “done” and have a valid answer

Unless you have almost all of these things, you almost certainly don’t have a general-problem-solver (though you may have something that is familiar with A LOT of solved problems).

Problems with Current Approaches

Now, let’s explore why SFT and RLHF struggle to meet most of these problem-solving capabilities. SFT is actually quite good at (1), which is why modern LLMs can handle well-known riddles or knowledge-based questions fairly well. Sampling and prompting techniques can help with (2). In narrow domains, models might handle (3) and (4) adequately, which is likely why techniques like “Self-Refine” show modest improvements. However, when trying to broaden (3) and (4) or tackle (5), (6), and (7), SFT starts to fall short. Let’s explore why.

SFT training is essentially a form of memorization. You feed the model many examples, aiming for it to reproduce them exactly. This works well for "knowledge" but not for exploration or productive trial-and-error. You might think, "What if we used SFT with problems that include many false starts before reaching the solution?" Even then, you're still teaching the model to reproduce both the wrong and right answers. Worse, you are teaching it to start with wrong answers, when you could simply have trained it to give the correct answer from the start.

So what about RLHF and friends? These allow the model to deviate slightly from its "natural" answers during training. Could this help it adopt better techniques when its initial approach isn't working? Unfortunately, the answer is still no, and the reason is more nuanced. Moving from "look at the difference between terms" to "cube the nth number and reverse the digits" isn’t a smooth, continuous change. It likely involves jumping from one “local maximum” in the reward function to another. RLHF excels at refining small adjustments, but its training process penalizes anything too far from what the model would have produced originally.

Can we just stop penalizing large deviations? Nope again! If the model makes "big jumps" from its usual output, it quickly diverges into gibberish, exploiting flaws in the reward model that give high rewards to poor answers. Allowing too much deviation causes the model to collapse. Since LLMs are already good at producing coherent text, we keep it "on track" by restricting how far it can drift from reasonable-looking text at any given time.
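This "stay close to the original model" constraint is typically implemented as a KL-style penalty folded into the reward. A minimal sketch, assuming per-sequence log-probabilities and a scalar score from the reward model (names are illustrative):

```python
def shaped_reward(reward_model_score, policy_logp, ref_logp, kl_coef=0.1):
    # Penalty grows as the policy assigns its output much more probability than
    # the frozen reference model would, i.e. as it drifts from its original behavior.
    kl_penalty = policy_logp - ref_logp
    return reward_model_score - kl_coef * kl_penalty
```

A lower `kl_coef` means more freedom to explore, but also more opportunity to exploit the reward model and collapse into gibberish.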

The Shakespearean Elephant in the Room

“Wait a minute!” I hear some of you saying. “I have seen my favorite LLM do things that were DEFINITELY not in the training data. How do you explain THIS using only memorization or incremental nudges?”

User: Tell me about the purpose of the quadratic formula, as if you were an elephant in a Shakespearean play.

GPT-4o: Ah, good sir! I, Elephanto the Wise, with trunk aloft and ears broad as the sails of yon ship, shall now impart to thee the purpose of the quadratic formula!

'Tis a marvel, a tool most divine, That seeks the roots, the very core Of equations squared in line. For when thou hast such a problem vexed— A quadratic, of form perplexed: Ax² + Bx + C, dost thou see? The formula dost come to thee!

x equals negative B, Plus or minus a square root's plea— Of B squared, from which subtract we must Four A and C in just and trust. Divide the whole by twice of A, And thou shalt find where answers lay!

Lo, its purpose clear, it aids with grace, To find where parabolas meet their place On the plane, where they cross the line so true— At x, at x, the roots come into view!

So now, with trunk raised high and proud, I trumpet forth this knowledge, loud! Quadratics dost bow 'fore such a gem, And with this formula, we conquer them!

A less playful version of this objection might be: “How can models generate valid code in contexts they haven’t seen before?” My understanding is that the transformer architecture, combined with current training methods, excels at interpolating between multiple memorized concepts activated by earlier context. This allows the model to produce outputs that are consistent with concepts like "Shakespearean," "elephant," and "quadratic formula."

Compare this to the "next number in the sequence" problem mentioned earlier. In that case, nothing is prompting the model to incorporate knowledge of cubed numbers or digit rearrangements into its output (see the ❓ footnote). It needs a different set of problem-solving tools to “discover” those connections.

Proof in the Pudding

So far, my case for why approaches short of full RL struggle with problem-solving has been intuitive. Now, let me provide something more convincing: data. A crucial aspect of our outlined problem-solving approach is the ability to course-correct when things aren’t working. A recent paper from Google DeepMind quantifies this for several models and approaches:

  • Accuracy@t1: how often a model produces a correct result on its first attempt
  • Accuracy@t2: how often it produces a correct result on its second attempt
  • 𝚫(t1, t2): the difference between accuracy on the second attempt and the first—in other words, whether the model did better on its second try
  • 𝚫i→c(t1, t2): how often it moved to a correct response given that it started with an incorrect one
  • 𝚫c→i(t1, t2): how often it moved to an incorrect response given that it started with a correct one
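For clarity, here is one plausible way to compute these metrics from per-problem correctness flags for a model's first and second attempts (the paper may normalize differently; this is just to pin down the definitions):

```python
def self_correction_metrics(first_correct, second_correct):
    # first_correct / second_correct: lists of booleans, one entry per problem.
    n = len(first_correct)
    pairs = list(zip(first_correct, second_correct))
    acc_t1 = sum(first_correct) / n
    acc_t2 = sum(second_correct) / n
    i_to_c = sum(1 for f, s in pairs if not f and s) / n   # fixed a wrong first attempt
    c_to_i = sum(1 for f, s in pairs if f and not s) / n   # broke a correct first attempt
    return {"Accuracy@t1": acc_t1, "Accuracy@t2": acc_t2,
            "Delta(t1, t2)": acc_t2 - acc_t1,
            "Delta_i->c(t1, t2)": i_to_c, "Delta_c->i(t1, t2)": c_to_i}
```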

The dataset for these results consists of math problems. Here, "Base model" refers to Gemini 1.5 Flash, unmodified and using a chain-of-thought prompt. Self-Refine is Gemini Flash utilizing a technique that prompts the model to iteratively critique and improve its answer. STaR and Pair-SFT both employ SFT, albeit in sophisticated ways that sample model generations. As for SCoRe, it appears to be the only approach that performs reasonably well at self-correction on this dataset. Spoiler: it incorporates RL, albeit in a limited manner.

Table 3 from https://arxiv.org/pdf/2409.12917

Footnotes

💡 I got this riddle by perusing the puzzles at https://oeis.org/wiki/Puzzle_Sequences, then finding one that was fairly easy to create a variation on. Searching with DuckDuckGo, I found the “riddle terms” of this “reversed cube” sequence in only 1 result, so it may be present in some training data, albeit with very low volume. Adding the next term (the answer) to the search, I find no results. The sequence does appear on OEIS, but it’s fair to say the model would have had a vanishingly small amount of representation for this exact problem in its training. I also asked o1-preview (in separate chat sessions) about versions of this using reversed squares, reversed 4th powers, and reversed 5th powers. A similar search gave no DuckDuckGo matches for the 4th or 5th power sequences, though both do appear on OEIS.

❓ Tellingly, ChatGPT-4o correctly identifies 297 as the answer if I modify my prompt to be: “What's the next term in this sequence? 1, 8, 72, 46, 521, 612, 343, 215, ... Explain your reasoning before answering. Consider cubed numbers and digit rearrangements.”
