PPO for Language Models: Adapting RL to Text Generation

Michael Brenndoerfer · December 28, 2025 · 43 min read

Learn how PPO applies to language models. Covers policy mapping, token action spaces, KL divergence penalties, and advantage estimation for RLHF.


PPO for Language Models

In the previous chapter, we explored the PPO algorithm as a general-purpose policy gradient method with clipped objectives and value function estimation. Now we face a crucial question: how do we apply these ideas to language models? The translation is difficult because traditional reinforcement learning and text generation have fundamental differences. Language models were designed as next-token predictors, not as agents acting in environments. They were trained to model the statistical patterns of human language, not to maximize cumulative rewards over time. Yet with some careful reframing, we can view text generation through the lens of sequential decision-making and apply PPO to steer models toward human-preferred outputs.

This chapter connects reinforcement learning to the structure of language generation. We will see how an LLM naturally serves as a stochastic policy, why the vocabulary forms a massive discrete action space, and how rewards propagate through generated sequences. Most importantly, we will understand why constraining the policy to stay close to its original behavior is essential for stable training. Each of these concepts builds on the last, forming a complete picture of how PPO transforms language model behavior.

The Language Model as a Policy

In reinforcement learning, a policy maps states to action probabilities. Given a state $s$, the policy $\pi(a \mid s)$ tells us the probability of taking action $a$. This definition describes decision-making: the agent observes its state, checks its policy, and chooses an action. Language models do the same: given a context, the model outputs a probability distribution over the next token. The parallel is exact.

To see this clearly, consider what happens when you type a prompt. The model processes your input, builds internal representations through its transformer layers, and produces a probability distribution over its vocabulary. This distribution assigns higher probabilities to tokens that would naturally continue the text and lower probabilities to tokens that would seem out of place. When the model generates a response, it samples from this distribution (or selects greedily), appends the chosen token to the context, and repeats the process. Each step involves observing a state and selecting an action according to a probability distribution, which is exactly what a policy does.

Let $x = (x_1, x_2, \ldots, x_n)$ denote the input prompt and $y = (y_1, y_2, \ldots, y_T)$ the generated response. At each generation step $t$, the language model computes:

$$\pi_\theta(y_t \mid x, y_{<t})$$

where:

  • $\pi_\theta$: the policy defined by model parameters $\theta$
  • $y_t$: the token generated at the current step $t$
  • $x$: the input prompt sequence
  • $y_{<t}$: the sequence of tokens generated prior to step $t$, $(y_1, \ldots, y_{t-1})$

This represents the probability of generating token $y_t$ given the full context of the prompt and previous tokens. The notation emphasizes that the policy depends on everything that came before: the original prompt establishes the task, and each previously generated token shapes what should come next. The model's parameters $\theta$ encode the learned patterns that determine how context maps to token probabilities.

Mapping language generation onto the standard RL formulation gives the following correspondences:

  • State $s_t$: The concatenation of the prompt and all tokens generated so far, $(x, y_1, \ldots, y_{t-1})$
  • Action $a_t$: The next token to generate, $y_t$
  • Policy $\pi_\theta(a_t \mid s_t)$: The LLM's softmax output distribution over the vocabulary
  • Trajectory $\tau$: The complete prompt-response pair $(x, y)$

The autoregressive generation process that produces a response is exactly a policy rollout. Starting from the initial state $s_1 = x$, we sample actions according to the policy; each action extends the state, and we continue until a stop token is generated. This equivalence is not just a useful metaphor; it is a precise mathematical correspondence that allows us to apply policy gradient methods to language generation. The theorems and techniques of policy optimization therefore carry over directly to steering language models toward desired behaviors.
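
To make the correspondence concrete, the sketch below writes autoregressive generation explicitly as a policy rollout. The model argument here is a hypothetical callable that returns next-token logits of shape (1, seq_len, vocab_size); it stands in for any autoregressive LLM and is not a specific library API.

import torch
import torch.nn.functional as F


def rollout(model, prompt_ids, eos_id, max_new_tokens=50):
    """Sample a trajectory: observe the state (context), sample an action (token), repeat."""
    state = prompt_ids  # s_1 = x, the prompt
    for _ in range(max_new_tokens):
        logits = model(state)[:, -1, :]            # policy logits for the current state
        probs = F.softmax(logits, dim=-1)          # pi_theta(a | s_t) over the vocabulary
        action = torch.multinomial(probs, 1)       # sample an action (the next token)
        state = torch.cat([state, action], dim=1)  # the action extends the state
        if action.item() == eos_id:                # a stop token ends the trajectory
            break
    return state  # the complete trajectory: prompt + response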

The Token Action Space

The action space in language generation is the model's vocabulary. This discrete set of tokens defines every possible action the policy can take at each step. Unlike continuous control problems where actions might be forces or velocities, or even discrete games where actions represent button presses, the language model's action space consists of linguistic units. These might be complete words, word fragments, punctuation marks, or special tokens that signal the end of generation.

For modern language models, this vocabulary typically contains between 30,000 and 100,000 tokens. The exact size depends on the tokenization algorithm used during pretraining. Byte-pair encoding, the most common approach, creates vocabularies that balance coverage of common words with the ability to represent rare words through subword decomposition. The vocabulary must be large enough to efficiently represent common text patterns while remaining small enough for the softmax computation to be tractable.

This scale is much larger than traditional RL domains. In classic control, action spaces are small. For example, a robot arm might have six joints. In Atari games, agents choose among roughly 18 discrete actions representing joystick directions and button combinations. Language models must select from tens of thousands of possible tokens at every single step. This represents an increase of three to four orders of magnitude in the number of discrete choices.

Out[2]:
Visualization
Comparison of action space sizes across common reinforcement learning domains. While traditional benchmarks like Atari or CartPole involve fewer than 100 discrete actions, language models must select from vocabularies of 50,000 tokens or more, presenting a significantly higher-dimensional decision problem.

This has significant implications:

  • Exploration is implicit: With such a large action space, the model cannot systematically try all options. There is no way to enumerate the consequences of every possible token choice. Instead, exploration emerges from the inherent stochasticity of sampling from high-entropy distributions during generation. When the model is uncertain, it assigns probability mass to many tokens, and sampling naturally explores these alternatives.
  • Credit assignment is diffuse: When a response receives a reward, determining which specific token choices contributed to that reward is challenging. Did the response succeed because of word choice in the third sentence, or the overall argument structure? The signal must somehow propagate back through dozens or hundreds of individual decisions.
  • Probability mass spreads thin: Even well-trained models may assign relatively small probabilities to any single token, making log probability computations numerically sensitive. When probabilities are small, their logarithms become large negative numbers, requiring careful numerical handling.

Despite these challenges, the discrete nature of the action space simplifies some aspects of PPO. We can compute exact action probabilities rather than approximating them, as we would need to do in continuous action spaces. We can compute the KL divergence between policies exactly by summing over the vocabulary. The softmax function gives us a proper probability distribution that sums to one, avoiding the density estimation challenges that arise with continuous distributions.
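
As a small illustration of this point, the following sketch computes the exact KL divergence between two next-token distributions by summing over a toy vocabulary; the random logits are placeholders, and a real model would simply make the sum longer.

import torch
import torch.nn.functional as F

vocab_size = 8  # toy vocabulary; a real model has tens of thousands of entries
logits_policy = torch.randn(vocab_size)
logits_ref = torch.randn(vocab_size)

log_p = F.log_softmax(logits_policy, dim=-1)  # log pi_theta(a | s)
log_q = F.log_softmax(logits_ref, dim=-1)     # log pi_ref(a | s)

# Exact KL(pi_theta || pi_ref) = sum over all tokens a of p(a) * (log p(a) - log q(a))
kl_exact = torch.sum(log_p.exp() * (log_p - log_q))
print(f"Exact KL over the vocabulary: {kl_exact.item():.4f}")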

In[3]:
Code
import torch
import torch.nn.functional as F

# Illustrate the action space concept
vocab_size = 32000  # Typical vocabulary size
hidden_dim = 768


# A simplified "policy head" that maps hidden states to action probabilities
class PolicyHead(torch.nn.Module):
    def __init__(self, hidden_dim, vocab_size):
        super().__init__()
        self.lm_head = torch.nn.Linear(hidden_dim, vocab_size, bias=False)

    def forward(self, hidden_state):
        # hidden_state: (batch, seq_len, hidden_dim)
        logits = self.lm_head(hidden_state)  # (batch, seq_len, vocab_size)
        return logits

    def get_action_probs(self, hidden_state, temperature=1.0):
        logits = self.forward(hidden_state)
        return F.softmax(logits / temperature, dim=-1)


policy_head = PolicyHead(hidden_dim, vocab_size)
num_params = sum(p.numel() for p in policy_head.parameters())

print(f"Action space size: {vocab_size:,} possible tokens")
print(f"Policy head parameters: {num_params:,}")
Out[4]:
Console
Action space size: 32,000 possible tokens
Policy head parameters: 24,576,000

Each position in the generated sequence requires selecting from this massive action space. A 100-token response involves 100 sequential decisions, each choosing among 32,000 options, so even a short response spans an astronomically large search space. This is why language generation relies on a learned policy to find good text efficiently rather than any form of exhaustive search.
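
A quick back-of-the-envelope calculation makes this scale tangible. The number of distinct 100-token responses over a 32,000-token vocabulary is $32000^{100}$, a number with roughly 450 decimal digits:

import math

# Number of decimal digits in 32000^100 is 100 * log10(32000)
digits = 100 * math.log10(32_000)
print(f"32,000^100 has about {digits:.0f} decimal digits")  # roughly 450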

States as Growing Contexts

The state in language generation grows with each action. Unlike a game where the agent might return to previously visited states, or a robot navigation task where the agent can revisit locations, language generation always moves forward. Each generated token permanently extends the context. The state at step $t+1$ contains everything from state $t$ plus the newly generated token.

Formally, the state at step $t$ is:

$$s_t = (x_1, x_2, \ldots, x_n, y_1, y_2, \ldots, y_{t-1})$$

where:

  • $s_t$: the state at timestep $t$
  • $x_i$: the $i$-th token of the prompt
  • $n$: the number of tokens in the prompt
  • $y_j$: the $j$-th generated token

This formulation captures the essential nature of autoregressive generation. The state is not a compact summary of the situation; it is the complete history of what has been written. The model must condition on all of this information to decide what should come next. A word that appeared ten sentences ago might be crucial for maintaining coherence, while a word from three tokens ago might determine grammatical constraints on the current position.

This means the state space is infinite and largely unique to each trajectory. Two different prompts lead to entirely different state spaces, and even the same prompt with different generation paths explores different states. Unlike board games where states might repeat (the same chess position can arise from different move orders) or continuous control where the system might return to similar configurations, language generation creates fresh territory with every token. Each state is a unique point in the space of possible text prefixes.

The growing state affects implementation. The model must process longer sequences as it generates tokens. Each new token requires attending to all previous tokens, so the total computation grows quadratically with sequence length. This is where techniques like KV caching, which we will cover in Part XXVIII, become essential for efficient inference and training. By storing the key and value projections of earlier tokens, the model avoids recomputing them, reducing the per-token cost from quadratic to linear in the sequence length.
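
The toy sketch below illustrates the idea behind the cache with a single attention head and made-up weight matrices: each new token computes only its own key and value, while earlier keys and values are read back from the cache. This is a conceptual illustration, not how production inference code is structured.

import torch


def attend(q, K, V):
    # Single-head scaled dot-product attention for one query vector
    scores = K @ q / K.shape[-1] ** 0.5
    weights = torch.softmax(scores, dim=0)
    return weights @ V


d = 16
W_q, W_k, W_v = (torch.randn(d, d) for _ in range(3))
cache_k, cache_v = [], []  # grow by one entry per generated token

for h in torch.randn(8, d):  # one hidden state per new token (placeholder values)
    cache_k.append(h @ W_k)  # compute the key and value only for the new token
    cache_v.append(h @ W_v)
    # Attend over all cached keys/values: cost per step is linear in the current length
    out = attend(h @ W_q, torch.stack(cache_k), torch.stack(cache_v))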

In[5]:
Code
def demonstrate_state_growth():
    """Show how states evolve during generation."""
    prompt_tokens = [101, 2054, 2003, 1996, 3007]  # "What is the capital"

    states = []
    generated = []

    # Simulate a generation trajectory
    response_tokens = [1997, 2605, 1029, 102]  # "of France? [SEP]"

    for t, token in enumerate(response_tokens):
        current_state = prompt_tokens + generated
        states.append(
            {
                "step": t,
                "state_length": len(current_state),
                "state": current_state.copy(),
                "action": token,
            }
        )
        generated.append(token)

    return states


trajectory = demonstrate_state_growth()
Out[6]:
Console
State evolution during generation:
--------------------------------------------------
Step 0: state length = 5, action = 1997
Step 1: state length = 6, action = 2605
Step 2: state length = 7, action = 1029
Step 3: state length = 8, action = 102

Final trajectory length: 4 steps
Out[7]:
Visualization
Monotonic state growth during the autoregressive generation process. Unlike environments with fixed state spaces or cycles, each generated token permanently extends the context, requiring the policy to condition on an ever-increasing history of previous decisions.

The generated trajectory illustrates how the context (state) expands with each step. Even a short response creates a sequence of unique states, as the growing history fundamentally changes the input to the policy at every decision point. At step 0, the policy sees only the prompt. At step 1, it sees the prompt plus one generated token. By the final step, it sees the entire conversation history. This expansion means that the policy faces a different decision problem at every timestep, even though the underlying question remains the same.

Reward Assignment in Sequential Generation

Reward models and RL algorithms use rewards differently, which creates challenges for PPO. The reward model, as we discussed in the chapter on Reward Modeling, takes a complete prompt-response pair and outputs a single scalar score. It evaluates the response as a whole, considering factors like helpfulness, coherence, accuracy, and safety. But PPO operates on trajectories with per-timestep rewards. The algorithm expects to receive a reward signal at each step, allowing it to compute advantages and update the policy accordingly.

The standard approach assigns the reward to the final token:

$$r_t = \begin{cases} R(x, y) & \text{if } t = T \text{ (final token)} \\ 0 & \text{otherwise} \end{cases}$$

where:

  • $r_t$: the reward assigned at step $t$
  • $R(x, y)$: the scalar score from the reward model for the complete response
  • $T$: the length of the generated sequence (final step)
  • $x$: the input prompt
  • $y$: the complete generated response

This approach reflects the fact that we only know how good a response is once it is complete. A partial response usually cannot be evaluated meaningfully, because qualities like coherence, accuracy, and helpfulness only emerge from the full text. So we wait until generation finishes, compute the reward, and assign it to the final step.

Sparse rewards make credit assignment difficult. The policy must learn which of its many token choices contributed to the final reward. PPO addresses this through its advantage estimation, but the challenge remains significant. A response might receive a low reward because of a single poor word choice, but that signal must propagate back through dozens of preceding tokens. The value function must learn to predict, from any intermediate state, what the expected final reward will be. This prediction is what allows advantages to differentiate between tokens: some tokens lead to states with high expected rewards, others to states with lower expectations.

In[8]:
Code
import numpy as np


def assign_rewards(response_length, final_reward):
    """
    Standard reward assignment: all reward at final step.

    Args:
        response_length: Number of tokens in the response
        final_reward: Scalar reward from the reward model

    Returns:
        Array of per-token rewards
    """
    rewards = np.zeros(response_length)
    rewards[-1] = final_reward  # Assign all reward to final token
    return rewards


def compute_returns(rewards, gamma=1.0):
    """
    Compute discounted returns from rewards.

    With gamma=1.0 (common in RLHF), each token's return
    equals the final reward.
    """
    T = len(rewards)
    returns = np.zeros(T)
    running_return = 0

    for t in reversed(range(T)):
        running_return = rewards[t] + gamma * running_return
        returns[t] = running_return

    return returns
In[9]:
Code
# Demonstrate reward assignment
response_len = 8
reward = 0.75

rewards = assign_rewards(response_len, reward)
returns = compute_returns(rewards)
Out[10]:
Console
Per-token rewards: [0.   0.   0.   0.   0.   0.   0.   0.75]
Per-token returns: [0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75]

Every token receives return = 0.75 (the final reward)
Out[11]:
Visualization
Comparison of sparse terminal rewards and their corresponding returns. The reward model evaluates the complete response to provide a single scalar score at the final token, which is then propagated back as a constant return for all tokens when the discount factor gamma = 1.0.

With a discount factor of $\gamma = 1.0$, which is standard in RLHF, every token in the sequence receives the same return, equal to the final reward. This uniform signal might seem uninformative at first glance. If every token gets the same return, how can the algorithm distinguish good tokens from bad ones? The answer lies in the advantage function, which compares the actual return to the expected return under the value function. A token that leads to a higher-than-expected return receives a positive advantage, while a token that leads to a lower-than-expected return receives a negative advantage. This differential signal, created by the value function's predictions, is what enables learning even with sparse rewards.
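
A tiny numerical example, using made-up value estimates, shows how the baseline creates a differential signal even though every token's return is the same 0.75:

import numpy as np

returns = np.full(4, 0.75)                   # every token's return equals the final reward
values = np.array([0.50, 0.80, 0.60, 0.70])  # hypothetical value estimates for each state

# Simple return-minus-baseline advantages (GAE adds temporal smoothing on top of this idea)
advantages = returns - values
print(advantages)  # [ 0.25 -0.05  0.15  0.05]

Tokens whose states the value function already rated highly receive small or negative advantages; tokens that led to better-than-predicted outcomes receive positive ones.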

The Critical Role of the KL Penalty

The most important adaptation for language model PPO is the KL divergence penalty. This constraint measures how far the policy drifts from its starting point and serves several purposes. Without this constraint, the optimized policy can diverge dramatically from the original model, often finding degenerate solutions that maximize reward without producing genuinely useful responses.

Reward Hacking Revisited

As we discussed in the chapter on Reward Hacking, optimizing a proxy reward (the learned reward model) rather than true human preferences creates opportunities for exploitation. The KL penalty is our primary defense against this failure mode.

Without the penalty, the policy is free to move anywhere in the space of possible token distributions. If the reward model has any exploitable patterns, any shortcuts that yield high scores without genuine quality, the unconstrained policy will find them. It might learn to generate repetitive phrases that the reward model scores highly. It might produce outputs that superficially resemble good responses while lacking substance. It might drift so far from natural language that it generates text no human would write. The KL penalty prevents these failure modes by keeping the policy anchored to the reference distribution.

The KL penalty modifies the reward at each timestep:

$$\tilde{r}_t = r_t - \beta \cdot \text{KL}_t$$

where:

  • $\tilde{r}_t$: the shaped reward at step $t$ used for PPO training
  • $r_t$: the original sparse reward
  • $\beta$: the KL penalty coefficient
  • $\text{KL}_t$: the KL divergence contribution at this step

The KL term is defined as the log-ratio between the current policy and a reference policy:

$$\text{KL}_t = \log \frac{\pi_\theta(y_t \mid s_t)}{\pi_{\text{ref}}(y_t \mid s_t)}$$

where:

  • $\pi_\theta(y_t \mid s_t)$: the probability of the chosen token under the current policy
  • $\pi_{\text{ref}}(y_t \mid s_t)$: the probability of the chosen token under the reference policy
  • $y_t$: the token generated at step $t$
  • $s_t$: the context (state) at step $t$

This formula is easy to interpret. When the current policy assigns higher probability to a token than the reference did, the log ratio is positive, and the policy is penalized. When the current policy assigns lower probability, the log ratio is negative, and the policy receives a bonus. The net effect is that the policy is discouraged from making dramatic probability changes in either direction. It can shift probabilities to improve reward, but only within bounds.

The reference policy $\pi_{\text{ref}}$ is typically the model after supervised fine-tuning (SFT) but before any RL training. This anchor prevents the policy from drifting into regions of token space that the original model considered highly unlikely. The SFT model represents our best current understanding of how to generate helpful, coherent text. By constraining the RL policy to stay near this baseline, we ensure that the optimized model retains the linguistic competence learned during pretraining and SFT.

The coefficient $\beta$ controls the strength of this constraint. Higher values keep the policy closer to the reference but limit learning; the policy cannot deviate much even when doing so would improve reward. Lower values allow more exploration but risk instability and reward hacking; the policy might find degenerate solutions that maximize reward while producing poor text. Success in RLHF depends on finding the right balance. We will explore the mathematical properties and tuning of this penalty in detail in the upcoming chapter on KL Divergence Penalty.

In[12]:
Code
def compute_kl_penalty(log_probs_policy, log_probs_ref):
    """
    Compute per-token KL divergence between policy and reference.

    This is a simplified approximation using only the chosen actions,
    not the full distribution over vocabulary.

    Args:
        log_probs_policy: Log probabilities under current policy
        log_probs_ref: Log probabilities under reference policy

    Returns:
        Per-token KL divergence estimates
    """
    # KL(policy || ref) ≈ log(policy(a)) - log(ref(a)) for chosen action a
    # This is an approximation; true KL sums over all actions
    kl = log_probs_policy - log_probs_ref
    return kl


def apply_kl_reward_shaping(rewards, log_probs_policy, log_probs_ref, beta=0.1):
    """
    Modify rewards with KL penalty.

    Args:
        rewards: Original per-token rewards (typically sparse)
        log_probs_policy: Log probs of chosen tokens under policy
        log_probs_ref: Log probs of chosen tokens under reference
        beta: KL penalty coefficient

    Returns:
        Modified rewards with KL penalty applied
    """
    kl = compute_kl_penalty(log_probs_policy, log_probs_ref)
    shaped_rewards = rewards - beta * kl
    return shaped_rewards, kl
In[13]:
Code
# Demonstrate KL penalty effect
np.random.seed(42)
seq_len = 6

# Simulate log probabilities
log_probs_policy = np.array([-1.2, -0.8, -2.1, -1.5, -0.9, -1.8])
log_probs_ref = np.array([-1.5, -0.9, -1.8, -1.6, -1.2, -2.0])

# Original sparse reward
original_rewards = assign_rewards(seq_len, 0.8)

# Apply KL shaping
shaped_rewards, kl_values = apply_kl_reward_shaping(
    original_rewards, log_probs_policy, log_probs_ref, beta=0.1
)

total_kl = kl_values.sum()
total_original = original_rewards.sum()
total_shaped = shaped_rewards.sum()
Out[14]:
Console
Per-token analysis:
------------------------------------------------------------
Token 0: original_r=0.00, KL=0.300, shaped_r=-0.030
Token 1: original_r=0.00, KL=0.100, shaped_r=-0.010
Token 2: original_r=0.00, KL=-0.300, shaped_r=0.030
Token 3: original_r=0.00, KL=0.100, shaped_r=-0.010
Token 4: original_r=0.00, KL=0.300, shaped_r=-0.030
Token 5: original_r=0.80, KL=0.200, shaped_r=0.780
------------------------------------------------------------
Total KL penalty: 0.700
Total original reward: 0.80
Total shaped reward: 0.730
Out[15]:
Visualization
Effect of KL reward shaping on the training signal. The penalty transforms the sparse terminal reward into dense per-token values by comparing the policy's log probabilities to a reference model, anchoring the optimization to the starting distribution and providing feedback at every step.

Notice how the KL penalty transforms the reward signal. The original sparse reward only provides signal at the final token; all intermediate tokens receive zero reward. After KL shaping, every token receives a reward component based on how much the policy diverges from the reference. Tokens where the policy assigns higher probability than the reference receive penalties (the KL term is positive, so it subtracts from the reward). Tokens where the policy assigns lower probability than the reference receive bonuses (the KL term is negative, so subtracting it adds to the reward). This transformation converts the sparse terminal reward into a dense per-token signal that guides the policy at every step.

PPO Objective for Language Models

Combining all these elements, the PPO objective for language models takes the following form. It maximizes the expected advantage while keeping the policy close to its starting behavior, and the clipping mechanism prevents destructively large updates that could destabilize training. For a batch of prompt-response pairs, we compute:

$$\mathcal{L}_{\text{PPO}}(\theta) = \mathbb{E}_{(x,y) \sim \pi_{\theta_{\text{old}}}} \left[ \sum_{t=1}^{T} \min\left( \rho_t(\theta) \hat{A}_t,\ \text{clip}(\rho_t(\theta), 1-\epsilon, 1+\epsilon) \hat{A}_t \right) \right]$$

where:

  • $\theta$: the parameters of the language model being optimized
  • $\mathbb{E}_{(x,y) \sim \pi_{\theta_{\text{old}}}}$: the expectation over trajectories generated by the policy that collected the data
  • $\rho_t(\theta)$: the probability ratio between the current and old policies at step $t$
  • $\hat{A}_t$: the estimated advantage at step $t$, indicating how much better the chosen action was compared to the average
  • $T$: the length of the generated sequence
  • $\epsilon$: the clipping parameter (typically 0.1 or 0.2) that defines the trust region

The min operator serves a crucial purpose in this objective. It takes the conservative lower bound between the unclipped and clipped objectives, ensuring that updates cannot be too aggressive. If the advantage is positive, indicating that the chosen action was better than expected, clipping limits how much we increase the probability of the action. If the advantage is negative, indicating the action was worse than expected, clipping limits how much we decrease the probability. This prevents large updates that could collapse the policy.
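
A tiny worked example with made-up ratios and advantages shows when the clipped branch takes over:

import numpy as np

eps = 0.2  # clipping parameter
for ratio, adv in [(1.0, 1.0), (1.5, 1.0), (0.6, -1.0)]:
    clipped = np.clip(ratio, 1 - eps, 1 + eps)
    objective = min(ratio * adv, clipped * adv)
    print(f"ratio={ratio:.1f}, advantage={adv:+.1f} -> objective={objective:+.2f}")

At ratio 1.0 the objective is just the advantage. At ratio 1.5 with a positive advantage, the objective is capped at 1.2 times the advantage, so there is no incentive to push the probability higher. At ratio 0.6 with a negative advantage, the clipped term (at 0.8) is the minimum, so further suppressing an already down-weighted token yields no additional gain.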

The probability ratio $\rho_t(\theta)$ is defined as:

$$\rho_t(\theta) = \frac{\pi_\theta(y_t \mid x, y_{<t})}{\pi_{\theta_{\text{old}}}(y_t \mid x, y_{<t})}$$

where:

  • $\pi_\theta(y_t \mid x, y_{<t})$: the probability of the token $y_t$ under the current policy (being updated)
  • $\pi_{\theta_{\text{old}}}(y_t \mid x, y_{<t})$: the probability of the same token under the policy that collected the data (frozen for this update step)
  • $y_t$: the token generated at step $t$
  • $x$: the input prompt
  • $y_{<t}$: the sequence of tokens generated prior to step $t$

This ratio measures how the policy has changed since the data was collected. A ratio of 1.0 means the policy assigns exactly the same probability as before. A ratio greater than 1 means the policy now assigns higher probability to this token. A ratio less than 1 means the probability has decreased. By basing updates on this ratio rather than directly on log probabilities, PPO can perform multiple gradient steps on the same batch of data without the policy drifting too far from the data-collection policy.

The advantage $\hat{A}_t$ is estimated using GAE (Generalized Advantage Estimation) applied to the KL-shaped rewards:

$$\hat{A}_t = \sum_{l=0}^{T-t} (\gamma \lambda)^l \delta_{t+l}$$

where:

  • $\gamma$: the discount factor (usually near 1.0 for RLHF)
  • $\lambda$: the GAE smoothing parameter (typically 0.95)
  • $T$: the length of the generated sequence
  • $\delta_{t+l}$: the temporal difference (TD) error at step $t+l$

This summation calculates an exponentially weighted average of future TD errors. The parameter $\lambda$ controls the bias-variance trade-off: higher values rely more on observed returns (reducing bias but increasing variance), while lower values rely more on value estimates (reducing variance but potentially introducing bias from imperfect value predictions). Setting $\lambda = 1$ recovers Monte Carlo returns, while $\lambda = 0$ gives pure one-step TD learning. The typical value of 0.95 balances these considerations.
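
Substituting the two extremes of $\lambda$ into the sum makes this concrete. With $\lambda = 0$ only the first term survives, while with $\lambda = 1$ the TD errors telescope into the full return minus the value baseline (using $V(s_{T+1}) = 0$):

$$\hat{A}_t\Big|_{\lambda=0} = \delta_t, \qquad \hat{A}_t\Big|_{\lambda=1} = \sum_{l=0}^{T-t} \gamma^l \delta_{t+l} = \sum_{l=0}^{T-t} \gamma^l \tilde{r}_{t+l} - V(s_t)$$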

The TD error $\delta_t$ measures the surprise at each step: the difference between the observed reward plus the estimated value of the next state, and the estimated value of the current state:

$$\delta_t = \tilde{r}_t + \gamma V(s_{t+1}) - V(s_t)$$

where:

  • $\tilde{r}_t$: the KL-shaped reward at step $t$
  • $\gamma$: the discount factor
  • $V(s_t)$: the value function estimate for state $s_t$
  • $V(s_{t+1})$: the value function estimate for the next state (defined as 0 if $t = T$)

The TD error captures whether the transition was better or worse than expected. If the reward plus next-state value exceeds the current-state value, the transition was unexpectedly good, and the TD error is positive. If the sum falls short, the transition was unexpectedly bad, and the TD error is negative. These signals, accumulated through GAE, produce advantage estimates that indicate which actions led to better outcomes than the value function predicted.

In[16]:
Code
import numpy as np


def compute_gae(rewards, values, gamma=1.0, lam=0.95):
    """
    Compute Generalized Advantage Estimation.

    Args:
        rewards: Per-token (shaped) rewards
        values: Value function estimates for each state
        gamma: Discount factor
        lam: GAE lambda parameter

    Returns:
        advantages: GAE advantage estimates
        returns: Target returns for value function
    """
    T = len(rewards)
    advantages = np.zeros(T)
    last_gae = 0

    # Assume terminal state has value 0
    next_value = 0

    for t in reversed(range(T)):
        delta = rewards[t] + gamma * next_value - values[t]
        advantages[t] = last_gae = delta + gamma * lam * last_gae
        next_value = values[t]

    returns = advantages + values
    return advantages, returns


def compute_ppo_loss(
    log_probs_new, log_probs_old, advantages, clip_epsilon=0.2
):
    """
    Compute clipped PPO policy loss.

    Args:
        log_probs_new: Log probs under current policy
        log_probs_old: Log probs under policy that collected data
        advantages: Advantage estimates (should be normalized)
        clip_epsilon: Clipping parameter

    Returns:
        PPO policy loss (to be maximized/negated for minimization)
    """
    # Probability ratio
    ratio = np.exp(log_probs_new - log_probs_old)

    # Clipped ratio
    clipped_ratio = np.clip(ratio, 1 - clip_epsilon, 1 + clip_epsilon)

    # PPO surrogate objectives
    surrogate1 = ratio * advantages
    surrogate2 = clipped_ratio * advantages

    # Take minimum (conservative update)
    loss = np.minimum(surrogate1, surrogate2)

    return loss.mean(), ratio, clipped_ratio
In[17]:
Code
# Demonstrate PPO loss computation
np.random.seed(123)

# Simulate a trajectory
seq_len = 5
shaped_rewards = np.array([0.02, -0.05, 0.03, -0.01, 0.85])  # KL-shaped
values = np.array([0.4, 0.5, 0.45, 0.6, 0.7])  # Value estimates

# Compute advantages
advantages, returns = compute_gae(shaped_rewards, values)

# Normalize advantages
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

# Log probabilities (old policy collected data, new policy being optimized)
log_probs_old = np.array([-1.2, -0.9, -1.5, -1.1, -0.8])
log_probs_new = np.array([-1.1, -0.85, -1.6, -1.0, -0.75])

# Compute loss
loss, ratio, clipped_ratio = compute_ppo_loss(
    log_probs_new, log_probs_old, advantages
)
Out[18]:
Console
PPO Loss Computation:
--------------------------------------------------
Advantages (normalized): [ 1.022 -0.099  1.171 -0.611 -1.484]
Probability ratios: [1.105 1.051 0.905 1.105 1.051]
Clipped ratios: [1.105 1.051 0.905 1.105 1.051]

PPO loss (to maximize): -0.0299
Out[19]:
Visualization
PPO clipping mechanism for positive and negative advantage scenarios. By limiting the probability ratio $\rho$ within a trust region (typically [0.8, 1.2]), the objective prevents excessively large policy updates that could destabilize training while still allowing the model to learn from high-advantage actions.

In this example, every raw advantage is positive: each step turned out somewhat better than the value function predicted. After normalization, the steps with the largest surprises receive positive values, while the final step, where the value estimate had already anticipated most of the reward, ends up with the most negative normalized advantage. The probability ratios stay within the trust region defined by the clipping parameter, demonstrating how PPO maintains stability. When a ratio exceeds $1 + \epsilon$ or falls below $1 - \epsilon$, the clipped ratio takes over, preventing the gradient from pushing the policy further in that direction.

Implementation: PPO Training Step

Let's now assemble a more complete implementation showing how these pieces fit together in a training step. This simplified version captures the essential structure while omitting some production details like distributed training, gradient accumulation, and advanced memory management. This example shows the flow of data and computation.

In[20]:
Code
import torch
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Tuple


@dataclass
class PPOConfig:
    """Configuration for PPO training."""

    clip_epsilon: float = 0.2
    kl_coef: float = 0.1
    value_coef: float = 0.5
    gamma: float = 1.0
    lam: float = 0.95


class LanguageModelPPO:
    """
    PPO trainer for language models.

    This simplified implementation demonstrates the key components
    without the full complexity of production systems.
    """

    def __init__(self, policy_model, ref_model, value_model, config: PPOConfig):
        """
        Args:
            policy_model: The LLM being optimized
            ref_model: Frozen reference model for KL computation
            value_model: Value function (often shares backbone with policy)
            config: PPO hyperparameters
        """
        self.policy = policy_model
        self.ref = ref_model
        self.value = value_model
        self.config = config

        # Freeze reference model
        for param in self.ref.parameters():
            param.requires_grad = False

    def compute_log_probs(self, model, input_ids, response_ids):
        """
        Compute log probabilities for response tokens.

        Args:
            model: Language model
            input_ids: Full sequence (prompt + response)
            response_ids: Just the response tokens

        Returns:
            Log probabilities for each response token
        """
        with torch.no_grad() if model == self.ref else torch.enable_grad():
            outputs = model(input_ids)
            logits = outputs.logits if hasattr(outputs, "logits") else outputs

            # Get logits for positions that predict response tokens
            # (shifted by 1 for autoregressive prediction)
            response_start = input_ids.shape[1] - response_ids.shape[1]
            response_logits = logits[:, response_start - 1 : -1, :]

            # Compute log probabilities
            log_probs = F.log_softmax(response_logits, dim=-1)

            # Gather log probs for actual tokens
            token_log_probs = torch.gather(
                log_probs, dim=-1, index=response_ids.unsqueeze(-1)
            ).squeeze(-1)

        return token_log_probs

    def compute_rewards_and_advantages(
        self,
        reward_scores: torch.Tensor,
        log_probs_policy: torch.Tensor,
        log_probs_ref: torch.Tensor,
        values: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Compute KL-shaped rewards and GAE advantages.
        """
        batch_size, seq_len = log_probs_policy.shape

        # KL penalty per token
        kl = log_probs_policy - log_probs_ref

        # Initialize rewards (sparse: only at final token)
        rewards = torch.zeros_like(log_probs_policy)
        rewards[:, -1] = reward_scores

        # Apply KL shaping
        shaped_rewards = rewards - self.config.kl_coef * kl

        # Compute GAE
        advantages = torch.zeros_like(shaped_rewards)
        last_gae = torch.zeros(batch_size, device=shaped_rewards.device)

        for t in reversed(range(seq_len)):
            next_value = (
                values[:, t + 1]
                if t < seq_len - 1
                else torch.zeros_like(last_gae)
            )
            delta = (
                shaped_rewards[:, t]
                + self.config.gamma * next_value
                - values[:, t]
            )
            advantages[:, t] = last_gae = (
                delta + self.config.gamma * self.config.lam * last_gae
            )

        returns = advantages + values

        return advantages, returns, kl

    def ppo_loss(
        self,
        log_probs_new: torch.Tensor,
        log_probs_old: torch.Tensor,
        advantages: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute clipped PPO policy loss.
        """
        # Normalize advantages
        advantages = (advantages - advantages.mean()) / (
            advantages.std() + 1e-8
        )

        # Probability ratio
        ratio = torch.exp(log_probs_new - log_probs_old)

        # Clipped surrogate
        clipped_ratio = torch.clamp(
            ratio, 1 - self.config.clip_epsilon, 1 + self.config.clip_epsilon
        )

        # Loss (negative because we minimize)
        policy_loss = -torch.min(
            ratio * advantages, clipped_ratio * advantages
        ).mean()

        return policy_loss

    def value_loss(
        self, values: torch.Tensor, returns: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute value function loss.
        """
        return F.mse_loss(values, returns)
In[21]:
Code
# Initialize configuration
config = PPOConfig()
Out[22]:
Console
LanguageModelPPO trainer initialized with components:
  - Policy model (being optimized)
  - Reference model (frozen, for KL computation)
  - Value model (critic)

Key hyperparameters:
  - clip_epsilon: 0.2
  - kl_coef (β): 0.1
  - value_coef: 0.5
  - GAE λ: 0.95

These hyperparameters define the constraints for the optimization. The clipping epsilon and KL coefficient are particularly critical for preventing the model from collapsing or drifting too far from its original capabilities. The value coefficient balances the policy and value function losses, and the GAE lambda controls the bias-variance trade-off in advantage estimation.

A Complete Training Loop

The following example demonstrates how PPO training proceeds at a high level. We use mock components to illustrate the data flow without requiring actual large models. This demonstration shows the essential rhythm of PPO training: generate responses, score them, compute advantages, and update the policy.

In[23]:
Code
import torch.nn as nn


class MockLanguageModel(nn.Module):
    """Simplified mock LM for demonstration."""

    def __init__(self, vocab_size=1000, hidden_dim=128):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        self.transformer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=4, batch_first=True
        )
        self.lm_head = nn.Linear(hidden_dim, vocab_size)

    def forward(self, input_ids):
        x = self.embedding(input_ids)
        x = self.transformer(x)
        logits = self.lm_head(x)
        return logits


class MockValueHead(nn.Module):
    """Value function head."""

    def __init__(self, hidden_dim=128):
        super().__init__()
        self.value_head = nn.Linear(hidden_dim, 1)

    def forward(self, hidden_states):
        return self.value_head(hidden_states).squeeze(-1)


class MockRewardModel(nn.Module):
    """Simplified reward model."""

    def __init__(self, vocab_size=1000, hidden_dim=64):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(hidden_dim, 1)

    def forward(self, input_ids):
        x = self.embedding(input_ids)
        x = self.pool(x.transpose(1, 2)).squeeze(-1)
        return self.head(x).squeeze(-1)
In[24]:
Code
def run_ppo_training_step(
    policy,
    ref_policy,
    value_head,
    reward_model,
    prompts,
    responses,
    optimizer,
    config,
):
    """
    Execute one PPO training step.

    Args:
        policy: Current policy model
        ref_policy: Reference policy (frozen)
        value_head: Value function
        reward_model: Reward model
        prompts: Batch of prompt token ids
        responses: Batch of response token ids
        optimizer: Optimizer for policy and value head
        config: PPO configuration

    Returns:
        Dictionary of training metrics
    """
    batch_size = prompts.shape[0]

    # Concatenate prompts and responses
    full_sequences = torch.cat([prompts, responses], dim=1)

    # Get reward scores for complete responses
    with torch.no_grad():
        reward_scores = reward_model(full_sequences)

    # Compute log probs under current and reference policy
    policy_logits = policy(full_sequences)
    with torch.no_grad():
        ref_logits = ref_policy(full_sequences)

    # Extract response portion
    prompt_len = prompts.shape[1]
    response_logits = policy_logits[:, prompt_len - 1 : -1, :]
    ref_response_logits = ref_logits[:, prompt_len - 1 : -1, :]

    # Compute log probabilities
    log_probs_policy = F.log_softmax(response_logits, dim=-1)
    log_probs_ref = F.log_softmax(ref_response_logits, dim=-1)

    # Gather log probs for chosen tokens
    chosen_log_probs = torch.gather(
        log_probs_policy, dim=-1, index=responses.unsqueeze(-1)
    ).squeeze(-1)

    chosen_log_probs_ref = torch.gather(
        log_probs_ref, dim=-1, index=responses.unsqueeze(-1)
    ).squeeze(-1)

    # Store old log probs for ratio computation
    old_log_probs = chosen_log_probs.detach()

    # Compute values
    hidden_states = policy.embedding(full_sequences)
    hidden_states = policy.transformer(hidden_states)
    values = value_head(hidden_states[:, prompt_len:, :])

    # Compute KL and advantages
    kl = chosen_log_probs - chosen_log_probs_ref

    # Sparse rewards + KL shaping
    rewards = torch.zeros_like(chosen_log_probs)
    rewards[:, -1] = reward_scores
    shaped_rewards = rewards - config.kl_coef * kl

    # Simple advantage computation (for demonstration)
    returns = torch.zeros_like(shaped_rewards)
    running_return = torch.zeros(batch_size)
    for t in reversed(range(responses.shape[1])):
        running_return = shaped_rewards[:, t] + config.gamma * running_return
        returns[:, t] = running_return

    # Treat advantages as constants: no gradient should flow through returns here
    advantages = (returns - values).detach()
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    # PPO loss
    ratio = torch.exp(chosen_log_probs - old_log_probs)
    clipped_ratio = torch.clamp(
        ratio, 1 - config.clip_epsilon, 1 + config.clip_epsilon
    )
    policy_loss = -torch.min(
        ratio * advantages, clipped_ratio * advantages
    ).mean()

    # Value loss
    value_loss = F.mse_loss(values, returns.detach())

    # Total loss
    total_loss = policy_loss + config.value_coef * value_loss

    # Optimization step
    optimizer.zero_grad()
    total_loss.backward()
    torch.nn.utils.clip_grad_norm_(
        list(policy.parameters()) + list(value_head.parameters()), max_norm=1.0
    )
    optimizer.step()

    return {
        "policy_loss": policy_loss.item(),
        "value_loss": value_loss.item(),
        "mean_reward": reward_scores.mean().item(),
        "mean_kl": kl.mean().item(),
        "mean_ratio": ratio.mean().item(),
    }
In[25]:
Code
# Run a demonstration training loop
torch.manual_seed(42)

vocab_size = 1000
hidden_dim = 128

# Initialize models
policy = MockLanguageModel(vocab_size, hidden_dim)
ref_policy = MockLanguageModel(vocab_size, hidden_dim)
ref_policy.load_state_dict(policy.state_dict())  # Start from same weights
value_head = MockValueHead(hidden_dim)
reward_model = MockRewardModel(vocab_size)

# Freeze reference
for param in ref_policy.parameters():
    param.requires_grad = False

# Optimizer
optimizer = torch.optim.Adam(
    list(policy.parameters()) + list(value_head.parameters()), lr=1e-4
)

config = PPOConfig()

# Training loop
metrics_history = []
for step in range(10):
    # Generate mock batch
    batch_size = 4
    prompt_len = 10
    response_len = 15

    prompts = torch.randint(0, vocab_size, (batch_size, prompt_len))
    responses = torch.randint(0, vocab_size, (batch_size, response_len))

    metrics = run_ppo_training_step(
        policy,
        ref_policy,
        value_head,
        reward_model,
        prompts,
        responses,
        optimizer,
        config,
    )
    metrics_history.append(metrics)
Out[26]:
Visualization
Evolution of key training metrics during PPO optimization. The plots track policy and value function loss, the mean KL divergence from the reference model, and the stability of the policy ratio, illustrating how clipping maintains the update within a controlled region.

The training metrics reveal several important dynamics that you should monitor during real RLHF runs. The policy loss oscillates as the model learns to balance reward maximization against the clipping constraint. The KL divergence tracks how far the policy drifts from the reference, a key quantity we want to keep bounded. If KL grows too large, the policy is moving into territory where the reward model may not be reliable. The probability ratio stays near 1.0 because the clipping mechanism prevents extreme updates, demonstrating that the trust region constraint is functioning as intended.
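
As a minimal sketch of this kind of monitoring, assuming the metrics_history list built above and an illustrative (not standard) threshold:

# Flag steps where the policy has drifted too far from the reference
KL_TARGET = 0.05  # illustrative value; in practice this is tuned per setup

for step, m in enumerate(metrics_history):
    if m["mean_kl"] > KL_TARGET:
        print(
            f"Step {step}: mean KL {m['mean_kl']:.3f} exceeds {KL_TARGET}; "
            "consider increasing kl_coef or lowering the learning rate"
        )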

Key Parameters

The key parameters for PPO training are:

  • clip_epsilon: The clipping threshold (typically 0.1 or 0.2) that constrains the policy update. Smaller values produce more conservative updates.
  • kl_coef: Coefficient for the KL penalty term, controlling how closely the policy must stay to the reference model. This is perhaps the most important hyperparameter for preventing reward hacking.
  • value_coef: Weight for the value function loss in the total optimization objective. Balancing this against the policy loss affects how quickly the value function adapts.
  • gamma: Discount factor for future rewards. In RLHF, this is typically set to 1.0 since we care equally about all tokens in the response.
  • lam: The GAE smoothing parameter that balances bias and variance in advantage estimation. Values around 0.95 are standard.

Practical Considerations

Implementing PPO for large language models involves several practical challenges beyond the algorithmic core. These engineering concerns often dominate the difficulty of real-world deployments and require careful attention.

Memory management is paramount. During training, you must store activations for the policy model, reference model, and value model simultaneously. The reference model can be loaded in half precision or quantized to reduce memory footprint. Some implementations share the backbone between policy and value models, adding only a small value head. This sharing reduces memory requirements but couples the representations, which may affect optimization dynamics.

Batch construction requires careful thought. Unlike supervised learning where examples are independent, PPO batches consist of complete generation trajectories. Generation is inherently sequential, making large-batch collection time-consuming. Implementations typically generate multiple responses in parallel across many prompts, using efficient batched inference to maximize throughput. The batch must contain enough diversity to provide a stable estimate of the gradient.

Response generation during training uses sampling rather than greedy decoding. This exploration is essential for PPO to discover high-reward responses that might differ from the reference policy's preferred outputs. Temperature and other sampling parameters become training hyperparameters that affect the exploration-exploitation balance. Too low a temperature leads to insufficient exploration; too high a temperature produces incoherent responses that receive low rewards.

Advantage normalization stabilizes training significantly. Normalizing advantages to have zero mean and unit variance across each batch prevents any single trajectory from dominating the gradient. Without normalization, a few outlier responses with extreme advantages could destabilize training by producing large gradient updates.

In[27]:
Code
# Calculate memory requirements for a 7B model
param_count = 7e9  # 7 billion
bytes_fp16 = 2
bytes_fp32 = 4

# Memory in GB
policy_mem = (param_count * bytes_fp16) / 1e9
ref_mem = (param_count * bytes_fp16) / 1e9

# Value head (approx 25M params)
value_head_params = 25e6
value_mem_mb = (value_head_params * bytes_fp32) / 1e6

# Optimizer (Adam stores momentums + variances in FP32 usually)
opt_mem = (param_count * bytes_fp32) / 1e9

total_mem = policy_mem + ref_mem + (value_mem_mb / 1000) + opt_mem
Out[28]:
Console
Memory footprint comparison for a 7B parameter model:
-------------------------------------------------------
Component                 Precision    Memory         
-------------------------------------------------------
Policy model              FP16         ~14 GB         
Reference model           FP16         ~14 GB         
Value head                FP32         ~100 MB        
Optimizer states          FP32         ~28 GB         
-------------------------------------------------------
Total (naive)                          ~56 GB         
Out[29]:
Visualization
Estimated memory footprint for PPO training of a 7-billion parameter model. The requirement to simultaneously host the policy model, a frozen reference model, and the optimizer states (often in full precision) makes memory management a primary bottleneck in LLM alignment.

Optimization strategies (a code sketch follows the list):

  • Load reference in 8-bit: saves ~7 GB
  • Gradient checkpointing: reduces activation memory
  • LoRA on policy: reduces optimizer state memory
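
The sketch below shows how these strategies might be wired up with the Hugging Face transformers and peft libraries. The model name and LoRA settings are illustrative assumptions, not recommendations.

from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

MODEL_NAME = "meta-llama/Llama-2-7b-hf"  # hypothetical 7B model

# Frozen reference model loaded in 8-bit, roughly halving its FP16 footprint
ref_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)

# Policy model: gradient checkpointing trades compute for activation memory,
# and LoRA keeps optimizer states only for the small adapter matrices
policy = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
policy.gradient_checkpointing_enable()
lora_config = LoraConfig(r=16, lora_alpha=32, task_type="CAUSAL_LM")
policy = get_peft_model(policy, lora_config)
policy.print_trainable_parameters()  # only the LoRA parameters require gradients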

Limitations and Challenges

Applying PPO to language models, while effective, comes with significant challenges that you must navigate. Understanding these limitations helps us find better methods.

Sparse rewards make credit assignment difficult. When a complete response receives a reward, determining which tokens contributed positively and which detracted is imprecise at best. The advantage function provides some differentiation through the value function's predictions, but it operates through imperfect estimates. The value function can only learn what patterns in intermediate states correlate with final rewards; it cannot directly observe the causal relationships. This imprecision can lead to slow learning, especially when the reward depends on subtle properties of the text that emerge from specific word choices or phrasings.

High computational cost is another barrier. PPO requires generating complete responses during training, which is much slower than the teacher-forcing paradigm used in supervised learning. Each training step involves running the policy model to generate text, running it again to compute log probabilities, running the reference model for KL computation, running the value model for advantage estimation, and running the reward model for scoring. Because of the computational and memory requirements, PPO is much more expensive than supervised fine-tuning. A single PPO training run can cost 10 to 100 times more than the equivalent SFT training.

The sensitivity to hyperparameters creates reproducibility challenges. The clipping parameter, KL coefficient, learning rate, and GAE parameters all interact in complex ways. Settings that work well for one model or task may fail on another. Small changes to any hyperparameter can lead to training instability or suboptimal results. This sensitivity motivates the development of alternative approaches, which we will explore in upcoming chapters on DPO and its variants.

Finally, the reliance on a learned reward model introduces all the challenges we discussed in the Reward Hacking chapter. The reward model is a proxy for human preferences, and optimizing it too aggressively can lead to responses that score highly according to the model while being less useful or even harmful. The KL penalty mitigates but does not eliminate this risk. Balancing reward optimization and constraints is an active research area.

Summary

This chapter translated the PPO algorithm from its general reinforcement learning formulation to the specific setting of language model alignment. The key conceptual mappings are:

  • A language model serves as a stochastic policy, with states being the growing context and actions being vocabulary tokens
  • The action space is the vocabulary, typically containing 30,000 to 100,000 discrete options
  • Rewards are assigned sparsely, with the reward model's score appearing only at the final token
  • The KL divergence penalty transforms sparse rewards into dense signal while preventing policy collapse

The PPO objective for language models combines the clipped surrogate loss with KL-shaped rewards, creating a training signal that balances reward maximization against staying close to the reference distribution. This balance is crucial for stable training and for avoiding reward hacking.

In the next chapter on the RLHF Pipeline, we will see how PPO fits into the complete workflow that transforms a pretrained language model into an aligned assistant, including the data collection, reward model training, and iterative refinement stages that surround the PPO optimization we have studied here.

Quiz

Ready to test your understanding? Take this quick quiz to reinforce what you've learned about applying PPO to language models.

