KL Divergence Penalty in RLHF: Theory & Implementation

Michael Brenndoerfer · December 30, 2025 · 43 min read

Learn how KL divergence prevents reward hacking in RLHF by keeping policies close to reference models. Covers theory, adaptive control, and PyTorch code.


KL Divergence Penalty

In the previous chapters, we built up the RLHF pipeline: collecting human preferences, training reward models, and applying PPO to optimize language model policies. But there's a critical problem we've only briefly mentioned: reward hacking. Without constraints, a policy can discover bizarre outputs that score highly on the reward model while being clearly worse by any human measure. The KL divergence penalty is the mechanism that prevents this collapse, keeping the fine-tuned model anchored to the capabilities of its pre-trained foundation.

This chapter explores why KL divergence is the right tool for this job, how to compute it efficiently for autoregressive models, and how to set and adapt the KL coefficient to balance learning against stability.

Why KL Divergence?

Recall from the Reward Hacking chapter that reward models are imperfect approximations of human preferences. They're trained on a finite dataset and will inevitably assign high scores to some outputs that humans would reject. When we optimize a policy using PPO, we're searching over a vast space of possible behaviors, and the optimization process will find these reward model errors if given enough freedom. This creates a trade-off in RLHF: we must improve the model using the reward signal without trusting it unconditionally.

The fundamental insight is that the pre-trained language model already represents a strong prior over sensible, grammatical, informative text. It was trained on trillions of tokens of human-written content, absorbing patterns of coherent discourse, factual knowledge, and linguistic conventions. By constraining our policy to stay close to this reference distribution, we inherit its linguistic competence while still allowing targeted improvements in alignment. Think of the reference model as a trusted expert whose judgment we respect: we want to make adjustments based on new feedback, but we don't want to stray so far that we lose the foundational competence that made the model useful in the first place.

KL divergence measures how different one probability distribution is from another. For language models, this means measuring how much the policy's probability assignments over tokens differ from the reference model's. At each position in a sequence, both models assign probabilities to every token in the vocabulary. KL divergence captures the extent to which these probability distributions disagree, weighted by the probability mass the policy assigns. The key properties that make KL divergence ideal for this application are:

  • Non-negativity: $D_{KL}(P \| Q) \geq 0$, with equality only when $P = Q$
  • Asymmetry: It penalizes the policy for assigning probability to tokens the reference doesn't expect
  • Decomposability: For autoregressive models, we can compute KL at each token position and sum

The non-negativity property ensures that the penalty term always works in one direction: it can only reduce the objective, never artificially inflate it. This mathematical guarantee means the optimization has a clear interpretation where reward pulls the policy toward better outputs while the KL penalty resists moving too far from the reference. The asymmetry property is particularly important for RLHF because it creates an asymmetric treatment of errors. If the policy assigns high probability to something the reference considers unlikely, the penalty is severe. But if the policy assigns low probability to something the reference considers likely, the penalty is more forgiving. This asymmetry encourages the policy to remain conservative, avoiding novel but potentially problematic outputs. Finally, decomposability allows us to compute and analyze the penalty at the level of individual tokens, providing fine-grained insight into where and how the policy is diverging.
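
To make these properties concrete, here is a minimal numeric sketch (using small made-up distributions over a five-token vocabulary, not outputs of any real model) that computes KL in both directions and confirms non-negativity and asymmetry:

import numpy as np


def kl_divergence(p: np.ndarray, q: np.ndarray) -> float:
    """Discrete KL divergence D_KL(P || Q) = sum_x P(x) * log(P(x) / Q(x))."""
    return float(np.sum(p * np.log(p / q)))


# Two made-up distributions over a tiny five-token vocabulary
p = np.array([0.40, 0.30, 0.15, 0.10, 0.05])  # stand-in for the policy
q = np.array([0.25, 0.25, 0.25, 0.15, 0.10])  # stand-in for the reference

print(kl_divergence(p, q))  # non-negative: D_KL(P || Q) >= 0
print(kl_divergence(q, p))  # a different value: KL is asymmetric
print(kl_divergence(p, p))  # exactly 0 when the distributions match

Swapping the arguments changes the result, which is exactly the asymmetry the RLHF penalty relies on.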

Out[2]:
Visualization: Decomposition of KL divergence between policy (P) and reference (Q) distributions. The left panel compares probability assignments over a vocabulary, highlighting disagreements on tokens like 'good', 'bad', and 'helpful'. The right panel illustrates the per-token contribution to the total divergence, where positive values (green) indicate the policy assigns higher probability than the reference, while negative values (red) reduce the total divergence.

Mathematical Foundation

The mathematical definition of KL divergence is the basis for these methods. The formula captures a precise notion of distributional difference that translates directly into the computational procedures we use during training.

For two discrete probability distributions P and Q over the same set of outcomes, KL divergence is defined as:

$$
\begin{aligned}
D_{KL}(P \| Q) &= \sum_x P(x) \log \frac{P(x)}{Q(x)} \\
&= \mathbb{E}_{x \sim P}\left[\log P(x) - \log Q(x)\right]
\end{aligned}
$$

where:

  • $D_{KL}(P \| Q)$: the Kullback-Leibler divergence from $Q$ to $P$, representing the information lost when $Q$ is used to approximate $P$
  • $P(x)$: the probability of outcome $x$ under the first distribution (typically the policy we are optimizing)
  • $Q(x)$: the probability of outcome $x$ under the reference distribution (the baseline or prior)
  • $\mathbb{E}_{x \sim P}$: the expectation calculated over samples drawn from $P$, meaning we average the log-ratio over the events that actually happen under distribution $P$

The first line presents the definition as a weighted sum over all possible outcomes, where each outcome's contribution is its probability under P times the log ratio of probabilities. The second line rewrites this as an expectation, which is the form we actually use in practice since we work with samples rather than complete distributions. This expectation-based form tells us something important: we only need to evaluate the log probability ratio for outcomes that actually occur under the policy, not for every possible outcome. This property is what makes KL divergence tractable for large vocabulary language models.

In the RLHF context, $P$ is the policy distribution $\pi_\theta$ and $Q$ is the reference distribution $\pi_{\text{ref}}$ (typically the supervised fine-tuned model before RLHF). The policy is the distribution we're optimizing, and the reference is the distribution we want to stay close to. Every time we sample a token from the policy, we can compute its contribution to the KL divergence by looking at how much more or less likely that token was under the policy compared to the reference.

For autoregressive language models, the probability of a complete response $y = (y_1, y_2, \ldots, y_T)$ given prompt $x$ factorizes as:

$$
\pi(y|x) = \prod_{t=1}^{T} \pi(y_t | x, y_{<t})
$$

where:

  • $\pi(y|x)$: the probability of the complete sequence $y$ given prompt $x$, computed as the product of conditional probabilities
  • $T$: the total length of the generated sequence (number of tokens)
  • $y_t$: the token generated at time step $t$
  • $y_{<t}$: the history of tokens generated before step $t$, which serves as the context for the next prediction

This factorization is the defining characteristic of autoregressive models. The probability of an entire sequence equals the product of probabilities for each token, where each token's probability depends on the prompt and all previously generated tokens. This chain structure means that generating a sequence requires $T$ sequential forward passes, one per token, but evaluating the probability of a sequence that is already known takes only a single pass, since the model computes the conditional distributions at all positions in parallel.
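
As a small illustrative sketch of this point, a random tensor below stands in for the logits a model would produce in one forward pass over a known sequence (teacher forcing); the sequence log-probability is then just the sum of the selected tokens' log-probabilities. The shapes and values are made up for illustration.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, seq_len = 100, 8

# Stand-in for model(input_ids).logits: position t scores the token at position t + 1
logits = torch.randn(1, seq_len, vocab_size)
tokens = torch.randint(0, vocab_size, (1, seq_len))

# One pass gives the conditional distribution at every position in parallel
log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)  # predictions for tokens 1..T-1
token_logprobs = torch.gather(
    log_probs, dim=-1, index=tokens[:, 1:].unsqueeze(-1)
).squeeze(-1)

# log pi(y | x) = sum_t log pi(y_t | x, y_<t)
sequence_logprob = token_logprobs.sum(dim=-1)
print(sequence_logprob)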

Taking the log and using the chain rule of KL divergence, the KL divergence between the policy and reference over complete responses can be written as:

$$
D_{KL}(\pi_\theta(\cdot|x) \| \pi_{\text{ref}}(\cdot|x)) = \mathbb{E}_{y \sim \pi_\theta(\cdot|x)}\left[\sum_{t=1}^{T} \log \frac{\pi_\theta(y_t | x, y_{<t})}{\pi_{\text{ref}}(y_t | x, y_{<t})}\right]
$$

where:

  • $D_{KL}(\pi_\theta \| \pi_{\text{ref}})$: the KL divergence between the policy and reference distributions over complete sequences
  • $\mathbb{E}_{y \sim \pi_\theta}$: the expectation calculated over response sequences $y$ sampled from the policy, since we estimate KL using the model's own generations
  • $\sum_{t=1}^{T}$: the summation over all $T$ tokens, accumulating the divergence contribution from each step
  • $\log \frac{\pi_\theta(\cdot)}{\pi_{\text{ref}}(\cdot)}$: the log probability ratio at each time step, which measures how much more (or less) likely a token is under the policy compared to the reference

This decomposition is crucial: we can compute KL as a sum of per-token log probability ratios, evaluated along the trajectory sampled from the policy. The log of a product becomes a sum of logs, which transforms the sequence-level KL into a sum of token-level contributions. Each term in the sum measures how much the policy and reference disagree about the probability of a specific token given the context. When the policy assigns higher probability than the reference, the term is positive, contributing to the divergence. When the policy assigns lower probability, the term is negative, reducing the divergence. The total KL divergence aggregates these token-level disagreements across the entire sequence.

Per-Token KL Computation

Given a sampled response, the per-token KL contribution is simply:

$$
\text{KL}_t = \log \pi_\theta(y_t | x, y_{<t}) - \log \pi_{\text{ref}}(y_t | x, y_{<t})
$$

where:

  • $\text{KL}_t$: the contribution to the KL divergence at time step $t$ (positive if the policy favors the token more than the reference)
  • $\pi_\theta(y_t | x, y_{<t})$: the probability assigned to the token $y_t$ by the policy model given the context
  • $\pi_{\text{ref}}(y_t | x, y_{<t})$: the probability assigned to the token $y_t$ by the reference model given the context

This formula reveals the elegant simplicity of per-token KL computation. For any token in a generated sequence, we simply subtract the reference model's log probability from the policy model's log probability. No summation over the vocabulary is required because we're only interested in the token that was actually generated. This is a direct consequence of the expectation-based formulation: since we're averaging over samples from the policy, we only need to evaluate the log ratio at the sampled outcomes.

Summing these over all tokens in the response gives:

$$
\begin{aligned}
\text{KL}(y) &= \sum_{t=1}^{T} \text{KL}_t \\
&= \sum_{t=1}^{T} \left[\log \pi_\theta(y_t | x, y_{<t}) - \log \pi_{\text{ref}}(y_t | x, y_{<t})\right]
\end{aligned}
$$

where:

  • $\text{KL}(y)$: the approximate KL divergence for a single sampled response $y$, calculated as the sum of per-token divergences
  • $T$: the length of the response in tokens, over which the divergence accumulates

This is an unbiased estimator of the KL divergence under the policy distribution, since we're sampling trajectories from $\pi_\theta$ itself. The unbiasedness property means that if we average this quantity over many sampled responses, we converge to the true KL divergence between the policy and reference distributions. In practice, we compute this estimator for each response in a training batch and use the batch average as our estimate of the expected KL divergence.
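
As a quick sanity check of this claim, the sketch below compares the exact KL (summed over a full, made-up vocabulary of 50 tokens) against the average of single-sample log-ratio estimates drawn from the policy; with enough samples the two agree closely.

import torch

torch.manual_seed(0)
vocab_size = 50

# Made-up policy and reference distributions for a single position
policy_probs = torch.softmax(torch.randn(vocab_size), dim=-1)
ref_probs = torch.softmax(torch.randn(vocab_size), dim=-1)

# Exact KL: weighted sum over the whole vocabulary
exact_kl = torch.sum(policy_probs * (policy_probs.log() - ref_probs.log()))

# Monte Carlo estimate: average log-ratio at tokens sampled from the policy
samples = torch.multinomial(policy_probs, num_samples=100_000, replacement=True)
mc_kl = (policy_probs.log()[samples] - ref_probs.log()[samples]).mean()

print(f"exact KL: {exact_kl:.4f}  Monte Carlo estimate: {mc_kl:.4f}")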

Out[3]:
Visualization: Per-token KL computation for a generated sequence. The left panel compares log probabilities, showing where the policy (blue) diverges from the reference (orange). The right panel quantifies these divergences as KL contributions; the policy's strong preference for 'helpful' (token 12) generates a large positive contribution, driving the total divergence.

The KL-Constrained Objective

Having established how to compute KL divergence, we now examine how it enters the RLHF optimization objective. The KL term acts as a regularizer, pulling against the reward signal to prevent the policy from straying too far from its trusted starting point.

The RLHF objective with KL penalty is:

$$
\mathcal{J}(\theta) = \mathbb{E}_{x \sim \mathcal{D}}\left[\mathbb{E}_{y \sim \pi_\theta(\cdot|x)}[r_\phi(x, y)] - \beta \cdot D_{KL}(\pi_\theta(\cdot|x) \| \pi_{\text{ref}}(\cdot|x))\right]
$$

where:

  • $\mathcal{J}(\theta)$: the objective function to be maximized, balancing reward maximization against drift from the reference
  • $\mathcal{D}$: the dataset of prompts used for training
  • $r_\phi(x, y)$: the learned reward model score for prompt $x$ and response $y$, providing the guidance signal
  • $\beta$: the KL coefficient controlling the strength of the penalty (higher values force the policy closer to the reference)
  • $\pi_{\text{ref}}$: the frozen reference model distribution, serving as the anchor

The structure of this objective captures the core tradeoff in RLHF. The first term rewards the policy for generating responses that score highly according to the reward model. The second term penalizes the policy for diverging from the reference. The coefficient β determines the relative importance of these two competing forces. When β is large, the penalty dominates and the policy barely moves from the reference. When β is small, the reward signal dominates and the policy can move freely toward high-reward regions, potentially exploiting imperfections in the reward model.

This can be rewritten using the per-token KL decomposition:

$$
\mathcal{J}(\theta) = \mathbb{E}_{x, y}\left[r_\phi(x, y) - \beta \sum_{t=1}^{T} \log \frac{\pi_\theta(y_t | x, y_{<t})}{\pi_{\text{ref}}(y_t | x, y_{<t})}\right]
$$

where:

  • $\mathbb{E}_{x, y}$: the expectation over prompts and sampled responses
  • $\sum_{t=1}^{T}$: the sum of log probability ratios over the sequence, representing the total divergence for that trajectory

This formulation makes the objective directly computable from model outputs. For each sampled response, we obtain a reward from the reward model and compute the sum of log probability ratios from the policy and reference models. The combination of these quantities, averaged over the batch, gives us the objective value that we then maximize through gradient ascent.

Connection to Constrained Optimization

The KL-penalized objective is equivalent to the Lagrangian relaxation of a constrained optimization problem:

$$
\max_\theta \; \mathbb{E}[r_\phi(x, y)] \quad \text{subject to} \quad \mathbb{E}[D_{KL}(\pi_\theta \| \pi_{\text{ref}})] \leq \epsilon
$$

where:

  • $\max_\theta$: the maximization over policy parameters $\theta$
  • $\mathbb{E}$: the expectation over the data distribution
  • $\epsilon$: the maximum allowable divergence (the "budget"), limiting how far the policy can drift

In this constrained view, we're asking for the policy that maximizes expected reward subject to a hard constraint on how much it can diverge from the reference. The constraint ε specifies a "KL budget" that the policy must respect. The penalized objective arises when we convert this constrained problem to an unconstrained one using Lagrangian relaxation.

The coefficient $\beta$ acts as a Lagrange multiplier. Larger $\beta$ corresponds to a tighter constraint (smaller $\epsilon$), forcing the policy to stay closer to the reference. This perspective helps us understand what we're actually optimizing: we want the highest-reward policy among those that remain within a KL "budget" of the reference. The correspondence between $\beta$ and $\epsilon$ is not one-to-one in practice because the constraint is enforced softly rather than exactly, but the intuition remains valuable. Choosing $\beta$ is analogous to choosing how much divergence we're willing to tolerate in exchange for reward improvement.
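
The sketch below illustrates this intuition with a toy, made-up concave reward-versus-KL curve (similar in shape to the one used in the simulation later in this chapter); maximizing the penalized objective numerically shows the optimal divergence shrinking toward zero as $\beta$ grows. None of the numbers correspond to a real policy.

import numpy as np

d = np.linspace(0, 30, 3001)        # candidate KL divergences (nats)
reward = 2 * (1 - np.exp(-d / 10))  # toy concave reward-vs-KL curve

for beta in [0.02, 0.05, 0.10, 0.15]:
    penalized = reward - beta * d        # Lagrangian-style penalized objective
    d_star = d[np.argmax(penalized)]     # KL at which the penalized objective peaks
    print(f"beta = {beta:.2f} -> optimal KL ~ {d_star:.1f} nats")

As the constrained-optimization view suggests, each $\beta$ implicitly selects a KL budget: the larger the coefficient, the smaller the divergence at which the penalized objective peaks.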

Reward Shaping Interpretation

Another way to view the KL term is as a per-token bonus added to the reward. This interpretation connects the KL penalty to the classical reinforcement learning concept of reward shaping, where the reward function is augmented with additional terms to guide learning.

We can rewrite the objective as:

$$
\mathcal{J}(\theta) = \mathbb{E}_{x, y}\left[\sum_{t=1}^{T} \gamma^t \cdot r_t\right]
$$

where:

  • $\gamma$: the discount factor (typically 1 in this context, treating all tokens as equally important)
  • $r_t$: the shaped reward at time step $t$, incorporating both the environment reward and the KL penalty

The shaped reward at each step is defined as:

$$
r_t = \begin{cases} -\beta \cdot \text{KL}_t & \text{if } t < T \\ r_\phi(x, y) - \beta \cdot \text{KL}_T & \text{if } t = T \end{cases}
$$

where:

  • $\text{KL}_t$: the per-token KL penalty, acting as an instantaneous cost for deviating from the reference
  • $r_\phi(x, y)$: the final reward from the reward model, typically given only at the end of the sequence ($t = T$)

This interpretation shows that the policy receives negative reward proportional to its deviation from the reference at each token, plus the final reward model score at the end of generation. The per-token penalties encourage staying on-distribution throughout the generation, not just at the final output. Every token choice that differs from what the reference would have chosen incurs a cost, creating pressure to remain conservative at every step of generation.

This reward shaping perspective has practical implications for implementation. In PPO, we need to assign credit to individual actions (tokens) for the overall outcome. The shaped reward formulation provides a natural way to do this: each token receives the KL penalty immediately, while the reward model score is attributed to the final token. This decomposition helps the value function and advantage estimation work effectively, since part of the reward signal is available at every time step rather than only at the end.
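
A minimal sketch of this shaping scheme, assuming we already have a per-token KL tensor (such as the one computed in the implementation section below) and one scalar reward-model score per sequence; the function name and shapes are illustrative, not a standard API.

import torch


def shape_rewards(
    per_token_kl: torch.Tensor,   # (batch, seq_len) per-token KL contributions
    final_rewards: torch.Tensor,  # (batch,) reward model score per sequence
    kl_coef: float,
) -> torch.Tensor:
    """Per-token shaped rewards: -beta * KL_t at every step, plus r_phi at the final step."""
    shaped = -kl_coef * per_token_kl   # instantaneous KL cost at every token
    shaped[:, -1] += final_rewards     # sparse terminal reward at the last token
    # With padded batches, the reward would instead be added at the last non-padding token.
    return shaped


# Tiny made-up example: two responses of four tokens each
per_token_kl = torch.tensor([[0.10, 0.30, 0.05, 0.20],
                             [0.00, 0.40, 0.10, 0.10]])
final_rewards = torch.tensor([1.0, -0.5])
print(shape_rewards(per_token_kl, final_rewards, kl_coef=0.1))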

Out[4]:
Visualization: Reward shaping interpretation of the KL penalty. Negative KL penalties (orange) are applied at every token step to penalize deviation from the reference, while the sparse reward model signal (blue) is added only at the final token. The total return integrates these instantaneous costs with the terminal reward, incentivizing the policy to remain on-distribution throughout generation.

KL Coefficient Selection

The KL coefficient $\beta$ is perhaps the most important hyperparameter in RLHF. It controls the fundamental tradeoff between:

  • Exploration/learning (low $\beta$): The policy can move freely toward high-reward regions, enabling rapid improvement
  • Conservation/stability (high $\beta$): The policy stays close to the reference, preserving language modeling capabilities

Effects of Different Coefficient Values

Too low ($\beta < 0.01$):

  • The policy optimizes almost purely for reward
  • Susceptible to reward hacking and mode collapse
  • May produce repetitive, formulaic outputs that exploit reward model quirks
  • Loses diversity and coherence over time

Too high ($\beta > 1.0$):

  • Learning becomes extremely slow
  • The policy barely moves from the reference
  • May never achieve meaningful alignment improvements
  • Effectively wastes computational resources

Typical range ($\beta \in [0.01, 0.2]$):

  • Most successful RLHF implementations use values in this range
  • InstructGPT used an initial value around 0.02
  • Anthropic's Constitutional AI work used values around 0.001-0.01
  • The optimal value depends on reward model quality and desired behavior change

Practical Guidelines for Selection

When selecting a KL coefficient, consider these factors:

  • Reward model confidence: If your reward model was trained on limited data or shows signs of miscalibration, use a higher $\beta$ to limit exploitation of its errors.

  • Magnitude of desired behavior change: For small adjustments (e.g., reducing specific harms), lower $\beta$ suffices. For significant capability changes, you may need to accept higher KL.

  • Response length: Longer responses accumulate more KL. If you're optimizing for verbose outputs, the effective per-token penalty is spread over more tokens, so you might need a higher $\beta$.

  • Training stability: If you observe reward increasing while output quality degrades (subjectively), increase $\beta$ to strengthen the constraint.

Out[5]:
Visualization: Impact of the KL coefficient $\beta$ on the optimization landscape. The left panel shows how increasing $\beta$ steepens the penalty slope, shifting the optimal KL divergence (dots) to the left. The right panel plots this relationship directly, demonstrating that the optimal KL divergence decreases inverse-proportionally as $\beta$ increases.

Adaptive KL Control

Rather than fixing $\beta$ throughout training, adaptive KL methods adjust the coefficient to maintain a target KL budget. This approach, introduced in early RLHF work by Ziegler et al. (2019), provides more consistent training dynamics across different stages of optimization. The core insight is that the appropriate constraint strength changes as training progresses, and a fixed coefficient cannot account for these changing conditions.

The Target KL Approach

The idea is to specify a target KL divergence $D_{\text{target}}$ that we want to maintain on average. If the current KL exceeds this target, we increase $\beta$ to pull the policy back. If KL is below target, we decrease $\beta$ to allow more exploration. This creates a feedback loop that stabilizes training by keeping the policy within a controlled distance from the reference.

The adaptation rule used in many implementations is:

$$
\beta_{t+1} = \begin{cases} 2 \cdot \beta_t & \text{if } D_{KL} > 1.5 \cdot D_{\text{target}} \\ \beta_t / 2 & \text{if } D_{KL} < D_{\text{target}} / 1.5 \\ \beta_t & \text{otherwise} \end{cases}
$$

where:

  • $\beta_{t+1}$: the updated KL coefficient for the next step
  • $\beta_t$: the current KL coefficient
  • $D_{KL}$: the observed average KL divergence in the current batch (the signal used for feedback)
  • $D_{\text{target}}$: the desired target KL divergence (the setpoint for the controller)

This multiplicative update creates a dead zone around the target where $\beta$ remains stable, preventing oscillation while still responding to significant deviations. The factor of 1.5 defines the boundaries of this dead zone: as long as the observed KL stays between $D_{\text{target}}/1.5$ and $1.5 \cdot D_{\text{target}}$, the coefficient remains unchanged. Only when KL drifts outside this range does the controller intervene with a multiplicative adjustment.

Why Adaptive KL Works

Adaptive KL addresses a fundamental challenge: the appropriate constraint strength changes during training. Early in RLHF, when the policy is close to the reference, small rewards can produce large gradient updates that rapidly increase KL. Later, when the policy has already diverged somewhat, the same learning rate produces smaller relative changes. A fixed β cannot account for these different regimes.

By targeting a consistent KL budget, adaptive control:

  • Prevents early-training instability from sudden KL spikes
  • Maintains learning signal late in training when a fixed $\beta$ might over-constrain
  • Provides a more interpretable hyperparameter (target KL in nats rather than abstract coefficient)

The interpretability benefit is substantial. When using a fixed β, it's difficult to know in advance what KL divergence will result. Different prompts, response lengths, and training stages all affect the relationship between β and actual KL. With adaptive control, you directly specify the divergence budget you're comfortable with, and the algorithm finds the appropriate β to achieve it.

Alternative Adaptation Schemes

Some implementations use smoother adaptation:

$$
\beta_{t+1} = \beta_t \cdot \left(1 + \alpha \cdot \text{sign}(D_{KL} - D_{\text{target}})\right)
$$

where:

  • $\alpha$: a small step size parameter (e.g., 0.1), controlling how aggressively $\beta$ changes
  • $\text{sign}(\cdot)$: the sign function (returns +1 if the term is positive, -1 if negative), determining the direction of the update

This approach prevents the jarring factor-of-2 jumps while still steering toward the target. The updates are smaller and more frequent, creating smoother β trajectories. However, this also means slower response to large deviations, which can be problematic if KL suddenly spikes due to a particularly influential batch.

Others use proportional control:

$$
\beta_{t+1} = \beta_t \cdot \exp\left(\alpha \cdot \frac{D_{KL} - D_{\text{target}}}{D_{\text{target}}}\right)
$$

where:

  • $\exp(\cdot)$: the exponential function, ensuring the multiplier is always positive
  • $\frac{D_{KL} - D_{\text{target}}}{D_{\text{target}}}$: the relative error from the target, scaling the update based on the magnitude of the deviation

This scales the adjustment magnitude by how far off-target the current KL is. Small deviations produce small adjustments, while large deviations produce large adjustments. The exponential ensures that β remains positive regardless of the error magnitude. This proportional approach offers a middle ground between the dead-zone method and the constant-step method, responding proportionally to the severity of the deviation.
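
For completeness, here is a hedged sketch of these two alternative update rules written as standalone functions; the parameter names (such as alpha) are illustrative choices rather than a standard API.

import math


def sign_update(beta: float, current_kl: float, target_kl: float, alpha: float = 0.1) -> float:
    """Constant-step update: nudge beta up or down depending on which side of the target we are on."""
    error = current_kl - target_kl
    direction = (error > 0) - (error < 0)  # sign of the error: +1, 0, or -1
    return beta * (1 + alpha * direction)


def proportional_update(beta: float, current_kl: float, target_kl: float, alpha: float = 0.1) -> float:
    """Proportional update: scale the adjustment by the relative error from the target."""
    relative_error = (current_kl - target_kl) / target_kl
    return beta * math.exp(alpha * relative_error)


beta = 0.1
print(sign_update(beta, current_kl=9.0, target_kl=6.0))          # fixed 10% increase
print(proportional_update(beta, current_kl=9.0, target_kl=6.0))  # increase scaled by the 50% relative error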

Code Implementation

Let's implement KL divergence computation and adaptive control for RLHF training. We'll build components that integrate with the PPO training loop from the previous chapter. The implementation focuses on clarity and correctness, providing the building blocks that can be optimized for production use.


Computing Per-Token KL Divergence

The core computation extracts log probabilities from both the policy and reference model, then computes their difference. This function forms the foundation of the KL penalty calculation, taking log probabilities that have already been extracted from model outputs and producing both per-token and per-sequence KL values.

In[7]:
Code
import torch
from typing import Optional, Tuple


def compute_token_kl(
    policy_logprobs: torch.Tensor,
    reference_logprobs: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute per-token and per-sequence KL divergence.

    Args:
        policy_logprobs: Log probs from policy for selected tokens (batch, seq_len)
        reference_logprobs: Log probs from reference for same tokens (batch, seq_len)
        attention_mask: Mask indicating valid tokens (batch, seq_len)

    Returns:
        per_token_kl: KL at each position (batch, seq_len)
        sequence_kl: Total KL per sequence (batch,)
    """
    # Per-token KL is simply the log ratio
    per_token_kl = policy_logprobs - reference_logprobs

    if attention_mask is not None:
        # Zero out KL for padding tokens
        per_token_kl = per_token_kl * attention_mask
        # Sum over valid tokens only
        sequence_kl = per_token_kl.sum(dim=-1)
    else:
        sequence_kl = per_token_kl.sum(dim=-1)

    return per_token_kl, sequence_kl

This computation is straightforward because we're using the sampled-trajectory estimator: the log probability ratio at the tokens actually generated equals the KL divergence in expectation. The attention mask handles variable-length sequences by zeroing out contributions from padding tokens, ensuring that the KL computation only considers actual content. This masking is essential when processing batches of sequences with different lengths, as padding tokens should not contribute to the divergence measure.
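
A quick usage sketch with made-up log probabilities shows how the mask removes padding positions from the sequence total (reusing compute_token_kl from the cell above):

import torch

# Made-up log probabilities for one sequence of four positions; the last one is padding
policy_lp = torch.tensor([[-1.0, -0.5, -2.0, -3.0]])
reference_lp = torch.tensor([[-1.2, -0.9, -1.5, -0.1]])
mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])

per_token_kl, sequence_kl = compute_token_kl(policy_lp, reference_lp, mask)
print(per_token_kl)  # roughly [[0.2, 0.4, -0.5, 0.0]]; the padding position contributes nothing
print(sequence_kl)   # roughly [0.1], the sum over the three valid tokens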

Extracting Log Probabilities from Model Outputs

To compute KL, we need log probabilities for the tokens that were actually generated. Here's how to extract them from model logits:

In[8]:
Code
import torch.nn.functional as F


def get_logprobs_for_tokens(
    logits: torch.Tensor, tokens: torch.Tensor
) -> torch.Tensor:
    """
    Extract log probabilities for specific tokens from logits.

    Args:
        logits: Model output logits (batch, seq_len, vocab_size)
        tokens: Token indices to get probs for (batch, seq_len)

    Returns:
        logprobs: Log probabilities for specified tokens (batch, seq_len)
    """
    # Convert logits to log probabilities
    log_probs = F.log_softmax(logits, dim=-1)

    # Gather log probs for the actual tokens
    # Shape: (batch, seq_len, 1) -> (batch, seq_len)
    token_logprobs = torch.gather(
        log_probs, dim=-1, index=tokens.unsqueeze(-1)
    ).squeeze(-1)

    return token_logprobs

The function first applies log_softmax to convert raw logits into log probabilities. This operation normalizes the logits so they represent a valid probability distribution at each position. The gather operation then selects the log probability corresponding to each actual token from the full vocabulary distribution. This is much more efficient than computing probabilities for all tokens when we only need one per position. The unsqueeze and squeeze operations handle the dimension manipulation required by PyTorch's gather function.
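
A small sanity check with random logits (shapes chosen only for illustration) confirms that the gathered values match direct indexing of the full log-probability tensor:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(2, 3, 10)          # (batch=2, seq_len=3, vocab=10)
tokens = torch.randint(0, 10, (2, 3))

token_logprobs = get_logprobs_for_tokens(logits, tokens)

# Manual check for one position: batch 0, time step 1
expected = F.log_softmax(logits, dim=-1)[0, 1, tokens[0, 1]]
print(torch.isclose(token_logprobs[0, 1], expected))  # tensor(True)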

Adaptive KL Controller

The adaptive controller maintains the KL coefficient and adjusts it based on observed KL values. It implements the target KL approach described in the mathematical section, providing a stateful object that can be updated after each training batch.

In[9]:
Code
class AdaptiveKLController:
    """
    Adaptive KL coefficient controller that targets a specific KL value.

    Based on the adaptive KL approach from Ziegler et al. (2019),
    as used in subsequent RLHF implementations.
    """

    def __init__(
        self,
        init_kl_coef: float = 0.1,
        target_kl: float = 6.0,
        horizon: int = 10000,
    ):
        """
        Args:
            init_kl_coef: Initial value of beta
            target_kl: Target KL divergence per response
            horizon: Number of steps over which to adapt (for smooth version)
        """
        self.kl_coef = init_kl_coef
        self.target_kl = target_kl
        self.horizon = horizon

        # Track history for monitoring
        self.kl_history = []
        self.coef_history = []

    def update(self, current_kl: float) -> float:
        """
        Update KL coefficient based on current KL divergence.

        Args:
            current_kl: Mean KL divergence from current batch

        Returns:
            Updated KL coefficient
        """
        self.kl_history.append(current_kl)

        # Multiplicative update (InstructGPT style)
        if current_kl > 1.5 * self.target_kl:
            self.kl_coef *= 2.0
        elif current_kl < self.target_kl / 1.5:
            self.kl_coef /= 2.0

        # Clamp to reasonable range
        self.kl_coef = max(min(self.kl_coef, 10.0), 0.001)

        self.coef_history.append(self.kl_coef)
        return self.kl_coef

    def get_coef(self) -> float:
        """Return current KL coefficient."""
        return self.kl_coef

The controller maintains a history of observed KL values and coefficient updates, which is useful for monitoring training dynamics and diagnosing issues. The clamping bounds prevent the coefficient from reaching extreme values that could either halt learning entirely (too high) or provide no regularization (too low). These bounds can be adjusted based on the specific application and model scale.

Let's test the adaptive controller with simulated KL values:

In[10]:
Code
# Simulate training dynamics
import numpy as np

np.random.seed(42)
initial_beta = 0.1
controller = AdaptiveKLController(init_kl_coef=initial_beta, target_kl=6.0)

# Simulate KL values that start high, then stabilize
simulated_kl = np.concatenate(
    [
        np.random.uniform(8, 12, 20),  # Early: KL too high
        np.random.uniform(5, 7, 30),  # Middle: near target
        np.random.uniform(2, 4, 20),  # Late: KL dropping
        np.random.uniform(5, 7, 30),  # Stabilized near target
    ]
)

for kl in simulated_kl:
    controller.update(kl)
Out[11]:
Console
Initial KL coefficient: 0.1
Target KL: 6.0

KL coefficient evolution:
  After high-KL phase (step 20): 10.0000
  After stable phase (step 50): 10.0000
  After low-KL phase (step 70): 0.0010
  Final value (step 100): 0.0010

The controller doubles $\beta$ when KL exceeds the target, pulling the policy back toward the reference, and halves it when KL drops too low, allowing more exploration. The simulation demonstrates how the controller responds to different training phases: increasing the coefficient during the high-KL early phase, holding it steady during the near-target phase, and decreasing it during the low-KL phase. Because the simulated KL values ignore the coefficient's effect on the policy, $\beta$ saturates at its clamp bounds (10.0 and 0.001) rather than settling at an intermediate value; in real training, the feedback between $\beta$ and the observed KL keeps the coefficient in a more moderate range.

Visualizing Adaptive KL Dynamics

Out[12]:
Visualization: Dynamics of the adaptive KL controller during simulated training. The left panel shows observed KL divergence fluctuating around the target (dashed red); when KL exits the stability zone (green), the controller triggers an update. The right panel tracks the coefficient $\beta$ responding to these excursions, increasing to curb high KL and decreasing to encourage exploration when KL is low.

The shaded region marks the "stable zone" where $\beta$ doesn't change. Outside this zone, the controller applies multiplicative updates to steer KL back toward the target.

Complete KL Penalty Integration

Here's how KL computation integrates into a simplified RLHF training step. This function combines all the pieces we've developed, taking model outputs and producing the penalized objective along with diagnostic information.

In[13]:
Code
def compute_rlhf_objective(
    policy_logits: torch.Tensor,
    reference_logits: torch.Tensor,
    response_tokens: torch.Tensor,
    rewards: torch.Tensor,
    attention_mask: torch.Tensor,
    kl_coef: float,
) -> dict:
    """
    Compute the RLHF objective with KL penalty.

    Args:
        policy_logits: Logits from current policy (batch, seq_len, vocab)
        reference_logits: Logits from reference model (batch, seq_len, vocab)
        response_tokens: Generated token indices (batch, seq_len)
        rewards: Reward model scores (batch,)
        attention_mask: Valid token mask (batch, seq_len)
        kl_coef: KL penalty coefficient (beta)

    Returns:
        Dictionary with objective value and diagnostics
    """
    # Get log probabilities for generated tokens
    policy_logprobs = get_logprobs_for_tokens(policy_logits, response_tokens)
    reference_logprobs = get_logprobs_for_tokens(
        reference_logits, response_tokens
    )

    # Compute KL divergence
    per_token_kl, sequence_kl = compute_token_kl(
        policy_logprobs, reference_logprobs, attention_mask
    )

    # KL-penalized reward
    penalized_rewards = rewards - kl_coef * sequence_kl

    # Mean objective (to be maximized)
    objective = penalized_rewards.mean()

    return {
        "objective": objective,
        "mean_reward": rewards.mean().item(),
        "mean_kl": sequence_kl.mean().item(),
        "mean_penalized_reward": penalized_rewards.mean().item(),
        "kl_penalty": (kl_coef * sequence_kl.mean()).item(),
        "per_token_kl": per_token_kl,
    }

The function returns a dictionary containing both the objective value for optimization and diagnostic quantities for monitoring. The separation between raw reward and penalized reward helps track whether improvements come from higher reward model scores or reduced divergence. The per-token KL tensor is included for detailed analysis of where the policy diverges most from the reference.

Let's demonstrate with synthetic data:

In[14]:
Code
# Create synthetic data
batch_size, seq_len, vocab_size = 4, 20, 1000
torch.manual_seed(42)

# Synthetic logits (policy slightly different from reference)
reference_logits = torch.randn(batch_size, seq_len, vocab_size)
policy_logits = reference_logits + 0.5 * torch.randn_like(reference_logits)

# Synthetic tokens and rewards
response_tokens = torch.randint(0, vocab_size, (batch_size, seq_len))
rewards = torch.tensor([0.8, 0.3, -0.2, 0.5])
attention_mask = torch.ones(batch_size, seq_len)

# Compute objective with different KL coefficients
low_beta = 0.01
results_low_beta = compute_rlhf_objective(
    policy_logits,
    reference_logits,
    response_tokens,
    rewards,
    attention_mask,
    kl_coef=low_beta,
)

high_beta = 0.5
results_high_beta = compute_rlhf_objective(
    policy_logits,
    reference_logits,
    response_tokens,
    rewards,
    attention_mask,
    kl_coef=high_beta,
)
Out[15]:
Console
Effect of KL Coefficient on Objective
=============================================

Low β (0.01):
  Mean reward:          0.3500
  Mean KL:              -2.2833
  KL penalty:           -0.0228
  Penalized reward:     0.3728

High β (0.5):
  Mean reward:          0.3500
  Mean KL:              -2.2833
  KL penalty:           -1.1416
  Penalized reward:     1.4916

With the same underlying KL estimate, a higher $\beta$ scales the penalty term proportionally: moving from $\beta = 0.01$ to $\beta = 0.5$ multiplies it by 50. Note that because the synthetic tokens here are drawn uniformly at random rather than sampled from the policy, the single-sample KL estimate happens to be negative, so the "penalty" actually raises the penalized reward in this toy example. In real training, where responses are sampled from the policy itself, the expected KL is non-negative, and a larger $\beta$ correspondingly shrinks the effective reward signal available for optimization.

Key Parameters

The key parameters for the KL divergence implementation are:

  • kl_coef ($\beta$): The weight of the KL penalty term. Controls the trade-off between reward maximization and reference adherence.
  • target_kl: The desired KL divergence value (in nats) for adaptive controllers. Used to dynamically adjust $\beta$.

Effects on Training Dynamics

The KL coefficient fundamentally shapes how RLHF training unfolds. Let's examine these dynamics through simulation.

Reward vs KL Tradeoff Frontier

During training, there's a Pareto frontier between reward maximization and KL minimization. Different $\beta$ values trace different paths along this frontier:

In[16]:
Code
def simulate_training_trajectory(
    beta: float,
    n_steps: int = 100,
    reward_scale: float = 1.0,
    noise_scale: float = 0.1,
) -> dict:
    """
    Simulate RLHF training trajectory for a given beta.

    This is a simplified model where:
    - Reward increases with distance from reference (up to a point)
    - KL increases proportionally with this distance
    - Higher beta limits how far the policy moves
    """
    np.random.seed(int(beta * 1000))

    kl_values = [0.0]
    reward_values = [0.0]

    # Effective learning rate decreases with beta
    effective_lr = 0.1 / (1 + beta)

    for step in range(n_steps):
        current_kl = kl_values[-1]

        # Reward model has peak at medium KL, degrades at high KL (reward hacking)
        true_quality = 2 * (1 - np.exp(-current_kl / 10)) - 0.1 * max(
            0, current_kl - 15
        )

        # Gradient pushes toward higher reward, KL penalty pushes back
        reward_gradient = 0.2 * np.exp(-current_kl / 10)
        kl_penalty_gradient = beta * 0.05

        # Net update to KL
        kl_update = effective_lr * (reward_gradient - kl_penalty_gradient)
        kl_update += noise_scale * np.random.randn()

        new_kl = max(0, current_kl + kl_update)
        new_reward = true_quality + noise_scale * np.random.randn()

        kl_values.append(new_kl)
        reward_values.append(new_reward)

    return {
        "kl": np.array(kl_values),
        "reward": np.array(reward_values),
        "beta": beta,
    }


# Simulate for different beta values
betas = [0.01, 0.05, 0.1, 0.3, 1.0]
trajectories = [simulate_training_trajectory(b) for b in betas]
Out[17]:
Visualization: Training trajectories in reward-KL space for different $\beta$ values. Lower $\beta$ values (warmer colors) allow the policy to drift deep into the 'reward hacking' region (red dashed line) to maximize reward. Higher $\beta$ values (cooler colors) constrain the trajectory, keeping the policy closer to the reference at the cost of lower total reward.

The visualization illustrates the fundamental tradeoff. Low-$\beta$ policies (warm colors) quickly accumulate KL divergence and can reach high rewards, but risk entering regions where the reward model is unreliable. High-$\beta$ policies (cool colors) progress slowly but stay in well-calibrated regions.

KL Distribution Over Tokens

KL divergence isn't uniform across tokens. Some positions contribute much more than others:

In[18]:
Code
# Simulate realistic per-token KL patterns
np.random.seed(42)
seq_len = 50

# Different patterns for different response types
# Pattern 1: High KL at beginning (new phrasing)
pattern_1 = np.exp(-np.linspace(0, 3, seq_len)) * 2 + 0.1

# Pattern 2: High KL at end (novel conclusions)
pattern_2 = np.exp(np.linspace(-3, 0, seq_len)) * 1.5 + 0.1

# Pattern 3: Spiky (specific word choices)
pattern_3 = 0.2 * np.ones(seq_len)
spike_positions = [5, 12, 23, 35, 45]
for pos in spike_positions:
    pattern_3[pos] = np.random.uniform(1.5, 3.0)

# Add some noise
patterns = {
    "Reformulated opening": pattern_1 + 0.1 * np.random.randn(seq_len),
    "Novel conclusion": pattern_2 + 0.1 * np.random.randn(seq_len),
    "Specific word choices": pattern_3
    + 0.05 * np.abs(np.random.randn(seq_len)),
}
Out[19]:
Visualization: Per-token KL divergence patterns for different types of policy modifications. 'Reformulated opening' concentrates divergence at the start, 'Novel conclusion' accumulates divergence at the end, and 'Specific word choices' shows isolated spikes where the policy selects different vocabulary than the reference.

Understanding where KL accumulates helps diagnose what the policy is learning. If KL concentrates at specific positions, the policy is making targeted modifications. If KL is high throughout, the policy is developing a significantly different generation style.
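
A simple diagnostic along these lines, sketched below: given a per-token KL tensor such as the one returned by compute_token_kl, report the positions that contribute most to each sequence's total. The helper name and the example values are made up for illustration.

import torch


def top_kl_positions(per_token_kl: torch.Tensor, k: int = 5):
    """Return the k positions (and their values) with the largest KL contributions per sequence."""
    values, positions = torch.topk(per_token_kl, k=k, dim=-1)
    return positions, values


# Made-up per-token KL for one 12-token response with two obvious spikes
per_token_kl = torch.tensor([[0.1, 0.05, 1.8, 0.2, 0.1, 0.02, 0.9, 0.1, 0.05, 2.3, 0.1, 0.2]])
positions, values = top_kl_positions(per_token_kl, k=3)
print(positions)  # tensor([[9, 2, 6]]) -- the spiky positions worth inspecting
print(values)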

Comparison: Fixed vs Adaptive KL

Let's compare training stability between fixed and adaptive KL control:

In[20]:
Code
def simulate_training_with_controller(
    controller_type: str,
    init_beta: float = 0.1,
    target_kl: float = 6.0,
    n_steps: int = 200,
) -> dict:
    """
    Simulate training with fixed or adaptive KL control.
    """
    np.random.seed(42)

    beta = init_beta
    kl_values = []
    beta_values = []
    reward_values = []
    penalized_reward_values = []

    # Simulated policy state
    policy_drift = 0.0

    for step in range(n_steps):
        # Current KL depends on how far policy has drifted
        base_kl = 2 * policy_drift
        current_kl = base_kl + np.random.exponential(1.0)

        # Reward increases with drift, but reward model noise increases too
        reward = 0.5 * np.log1p(policy_drift) + 0.2 * np.random.randn()
        penalized_reward = reward - beta * current_kl

        kl_values.append(current_kl)
        beta_values.append(beta)
        reward_values.append(reward)
        penalized_reward_values.append(penalized_reward)

        # Policy update: gradient step in direction of penalized reward
        policy_update = 0.05 * (1 - beta * 0.5)
        policy_drift = max(0, policy_drift + policy_update)

        # Update beta for adaptive controller
        if controller_type == "adaptive":
            if current_kl > 1.5 * target_kl:
                beta = min(beta * 2, 10.0)
            elif current_kl < target_kl / 1.5:
                beta = max(beta / 2, 0.001)

    return {
        "kl": np.array(kl_values),
        "beta": np.array(beta_values),
        "reward": np.array(reward_values),
        "penalized_reward": np.array(penalized_reward_values),
        "controller": controller_type,
    }


fixed_results = simulate_training_with_controller("fixed")
adaptive_results = simulate_training_with_controller("adaptive")
Out[21]:
Visualization: Training stability comparison between fixed (orange) and adaptive (blue) KL coefficients. The adaptive controller maintains KL divergence near the target by dynamically adjusting $\beta$, resulting in stable penalized rewards. In contrast, the fixed controller allows KL to drift, causing the penalized reward to degrade over time.

The adaptive controller maintains consistent KL divergence throughout training by increasing $\beta$ when the policy drifts too far. This results in more stable penalized rewards compared to the fixed controller, where the growing KL penalty eventually dominates the objective.

Limitations and Practical Considerations

While KL divergence is the standard constraint for RLHF, it has several limitations worth understanding.

Choice of reference model matters: The KL penalty anchors the policy to a specific reference, typically the SFT model. If the SFT model has problematic behaviors, the KL penalty resists correcting them. Conversely, if the SFT model is highly capable, the constraint preserves that capability. This makes the quality of supervised fine-tuning crucial for RLHF success.

KL is a distributional constraint, not a behavioral one: KL divergence measures similarity between probability distributions, not between actual behaviors. Two policies could have small KL divergence but generate quite different outputs when sampled, especially for long sequences where small per-token differences compound. Conversely, policies with large KL might produce similar outputs if the probability differences are in low-probability regions.

Asymmetry creates specific biases: The direction used in RLHF, $D_{KL}(\pi_\theta \| \pi_{\text{ref}})$, takes the expectation under the policy, so it penalizes the policy heavily for placing probability on tokens the reference considers unlikely, but is comparatively forgiving when the policy withdraws probability from tokens the reference likes. The policy can therefore concentrate on a narrower subset of the reference's high-probability behaviors, which can be desirable (staying conservative and on-distribution) or undesirable (losing diversity through mode collapse).

Computational overhead: Computing KL requires running both the policy and reference model on every training example. For large models, this roughly doubles the forward pass compute. Some implementations amortize this by caching reference model log probabilities, but this requires additional memory.
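
One common way to amortize this cost is sketched below, assuming Hugging Face-style models whose outputs expose a .logits attribute (the function name and call signature are illustrative): run the frozen reference model once per batch under torch.no_grad(), keep only the gathered per-token log probabilities, and reuse that tensor for every optimization epoch over the batch.

import torch


@torch.no_grad()
def cache_reference_logprobs(reference_model, input_ids, response_tokens, attention_mask):
    """Run the frozen reference model once and keep only the per-token log probs."""
    logits = reference_model(input_ids, attention_mask=attention_mask).logits
    # Reuse the helper defined earlier; storing only the gathered log probs avoids
    # keeping the full (batch, seq_len, vocab) logits tensor in memory.
    return get_logprobs_for_tokens(logits, response_tokens)


# Inside the PPO loop, the cached tensor replaces repeated reference forward passes:
# ref_logprobs = cache_reference_logprobs(ref_model, input_ids, response_tokens, mask)
# per_token_kl, sequence_kl = compute_token_kl(policy_logprobs, ref_logprobs, mask)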

Looking ahead, methods like Direct Preference Optimization (DPO), covered in the next chapter, incorporate the KL constraint implicitly in their objective. This eliminates the need for explicit KL computation during training while achieving similar regularization effects, representing an elegant alternative to the explicit penalty approach we've examined here.

Summary

The KL divergence penalty is essential for stable RLHF training. It prevents reward hacking by keeping the policy close to a trusted reference model, preserving the language capabilities acquired during pre-training while allowing targeted improvements in alignment.

Key takeaways from this chapter:

  • KL divergence measures distribution difference: For autoregressive models, it decomposes into a sum of per-token log probability ratios, enabling efficient computation along sampled trajectories.

  • The coefficient β controls the reward-constraint tradeoff: Low β allows rapid learning but risks exploitation of reward model errors. High β maintains stability but slows improvement. Typical values range from 0.01 to 0.2.

  • Adaptive KL control maintains consistent constraints: By adjusting β to target a specific KL budget, adaptive methods provide stable training dynamics throughout optimization, preventing both early-training instability and late-training stagnation.

  • Per-token KL reveals what the policy learns: Analyzing where KL accumulates in generated sequences helps diagnose whether the policy is making targeted modifications or developing a fundamentally different generation style.

The KL penalty represents one approach to constraining policy optimization. As we'll see in the following chapters on DPO, there are alternative formulations that achieve similar goals through different mechanisms, offering tradeoffs in implementation complexity, computational efficiency, and training stability.

Quiz

Ready to test your understanding? Take this quick quiz to reinforce what you've learned about KL divergence penalties in RLHF.

