NTK-aware Scaling: Extending Context Length in LLMs

Michael Brenndoerfer · Updated July 1, 2025 · 33 min read

Learn how NTK-aware scaling extends transformer context windows by preserving high-frequency position information while scaling low frequencies for longer sequences.

NTK-aware Scaling

Position Interpolation, as we saw in the previous chapter, extends context length by simply scaling down all rotation frequencies. While this works, it comes with a hidden cost: by compressing all frequencies equally, we lose the fine-grained positional distinctions that high-frequency components provide. Consider a model trained with 4,096 tokens. Position Interpolation at 8x scale compresses everything to fit 32,768 positions, but now adjacent tokens that were once easily distinguishable become nearly identical in their high-frequency dimensions.

NTK-aware scaling takes a more surgical approach. Instead of treating all frequencies equally, it recognizes that different frequency bands serve different purposes. High frequencies distinguish nearby tokens, while low frequencies capture long-range structure. By scaling frequencies non-uniformly, NTK-aware methods preserve local precision while still extending the effective context window.

This chapter develops the intuition behind NTK-aware scaling, derives the mathematical formula, and implements both static and dynamic variants. By the end, you'll understand why this approach often outperforms linear interpolation, especially for tasks requiring precise local attention.

The Problem with Uniform Scaling

To understand why uniform scaling falls short, we need to revisit how RoPE frequencies work. Recall that RoPE assigns each dimension pair $i$ a base frequency that determines how fast the rotation angle changes with position:

$$\theta_i = \frac{1}{b^{2i/d}}$$

where:

  • $\theta_i$: the base rotation frequency (in radians per position) for dimension pair $i$
  • $b$: the base constant, typically 10000, which controls the overall frequency range
  • $i$: the dimension pair index, ranging from 0 to $d/2 - 1$
  • $d$: the dimension being rotated, in practice the per-head dimension of attention (must be even since we work in pairs)
  • $2i/d$: the exponent that creates a geometric progression of frequencies across dimensions

The first dimension pair ($i = 0$) has frequency $\theta_0 = 1/10000^0 = 1$, meaning it rotates by 1 radian per position and completes a full rotation every $2\pi \approx 6.28$ positions. The last pair ($i = d/2 - 1$) has a much smaller frequency, completing a rotation only after tens of thousands of positions (about 47,000 for $d = 64$). This spread creates a multi-scale representation:

  • High-frequency pairs ($i$ near 0): Distinguish nearby tokens with precision. Tokens at positions 5 and 6 look very different.
  • Low-frequency pairs ($i$ near $d/2$): Capture coarse position information. Tokens far apart have noticeably different rotations.
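
A short sketch makes the spread concrete. It assumes $d = 64$ and $b = 10000$ (the values used throughout this chapter) and computes each pair's frequency together with its period $2\pi/\theta_i$, the number of positions needed for one full rotation:

```python
import numpy as np


def rope_periods(d=64, base=10000.0):
    """Return per-pair RoPE frequencies (rad/position) and periods (positions per rotation)."""
    i = np.arange(d // 2)                  # dimension pair indices 0 .. d/2 - 1
    theta = 1.0 / base ** (2 * i / d)      # theta_i = 1 / b^(2i/d)
    period = 2 * np.pi / theta             # positions needed to rotate a full 2*pi
    return theta, period


theta, period = rope_periods()
print(f"pair 0:  {theta[0]:.6f} rad/pos, period {period[0]:.2f} positions")
print(f"pair 31: {theta[-1]:.6f} rad/pos, period {period[-1]:,.0f} positions")
```

The first pair repeats every $2\pi \approx 6.28$ positions, while the last needs roughly 47,000, which is why the low-frequency pairs can still separate tokens that are thousands of positions apart.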

Position Interpolation addresses context extension by scaling all frequencies uniformly:

$$\theta_i' = \frac{\theta_i}{s}$$

where:

  • $\theta_i'$: the scaled frequency for dimension pair $i$ after Position Interpolation
  • $\theta_i$: the original frequency for dimension pair $i$
  • $s$: the context extension factor, computed as $s = L_{\text{target}} / L_{\text{train}}$ (target length divided by training length)

If we want to extend from 4,096 to 32,768 tokens, $s = 32768 / 4096 = 8$. Every frequency gets divided by 8. The high-frequency pair that previously rotated 1 radian per position now rotates only $1/8 = 0.125$ radians. Adjacent tokens, which were once clearly distinguishable, become nearly indistinguishable in these dimensions.
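
The arithmetic can be checked directly. This sketch assumes $b = 10000$, $d = 64$, and an extension from 4,096 to 32,768 tokens. It also shows why the method is called interpolation: dividing every frequency by $s$ is the same as dividing every position by $s$, so all 32,768 positions map into the range $[0, 4096)$ that the model saw during training.

```python
import numpy as np

d, base = 64, 10000.0
train_len, target_len = 4096, 32768
s = target_len / train_len                         # extension factor: 8.0

theta = 1.0 / base ** (2 * np.arange(d // 2) / d)  # original RoPE frequencies
theta_pi = theta / s                               # Position Interpolation: uniform division

print("adjacent-token angle at i=0, original:", theta[0])
print("adjacent-token angle at i=0, PI:      ", theta_pi[0])

# Dividing frequencies by s is equivalent to dividing positions by s.
positions = np.arange(target_len)
effective_positions = positions / s
assert np.allclose(np.outer(positions, theta_pi), np.outer(effective_positions, theta))
print("largest effective position:", effective_positions.max())   # stays below train_len
```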

Let's visualize this problem:

Out[3]:
Visualization
Line plot showing rotation angles increasing with position for original and scaled RoPE.
Rotation angles at each position for the high-frequency dimension (i=0). Original RoPE increases steeply while Position Interpolation compresses to a shallow slope.
Bar chart comparing angular separation between adjacent tokens.
Angular separation between adjacent tokens. Position Interpolation with 8x scaling reduces separation from 1 radian to 0.125 radians.

The compression is dramatic. With Position Interpolation at 8x scale, adjacent tokens are separated by only 0.125 radians in the highest-frequency dimension, compared to 1 radian originally. This means the model has far less "room" to distinguish nearby positions, potentially harming tasks that require precise local attention.

From Intuition to Formula: The NTK-aware Approach

Now that we've seen the problem with uniform scaling, let's develop a solution from first principles. The journey from intuition to a working formula requires answering three questions: What property do we want? How can we achieve it mathematically? And does the result match our expectations?

Neural Tangent Kernel (NTK)

The Neural Tangent Kernel describes how neural networks behave during training in the infinite-width limit. In this context, "NTK-aware" refers to preserving the high-frequency components that the network needs to learn fine-grained distinctions between nearby positions.

The Core Insight: Frequency Bands Serve Different Purposes

The key realization is that not all frequencies are equal in importance. Think of RoPE's multi-frequency design like a ruler with both centimeter and millimeter markings. The millimeter marks (high frequencies) give you precision for small measurements, while the centimeter marks (low frequencies) help you measure longer distances quickly. Position Interpolation is like shrinking the entire ruler by 8x, making even the millimeter marks too small to read. What we really want is to keep the millimeter precision while adjusting the centimeter scale.

Translating this intuition into concrete requirements:

  1. Preserve high frequencies: The fastest-rotating dimension pairs (small $i$) should remain unchanged. These are our "millimeter marks" for local position distinctions.

  2. Scale low frequencies: The slowest-rotating dimension pairs (large $i$) can be compressed by the full factor $s$. These are our "centimeter marks" that need adjustment for longer contexts.

  3. Smooth transition: Intermediate dimensions should scale gradually between these extremes, not abruptly jump.

Designing the Solution: Modify the Base

How can we achieve dimension-dependent scaling? The original RoPE frequency formula is $\theta_i = 1/b^{2i/d}$. Position Interpolation modifies the output by dividing: $\theta_i' = \theta_i / s$. This gives uniform scaling because every frequency is divided by the same constant.

Instead, we'll modify the input: the base $b$. If we replace $b$ with a larger base $b'$, every frequency except the first decreases (slower rotation). The key insight is that the exponent $2i/d$ causes this decrease to affect different dimensions differently:

  • When $i = 0$: $\theta_0' = 1/(b')^0 = 1$. No matter what $b'$ is, the highest frequency remains 1.
  • When $i$ is large: $\theta_i' = 1/(b')^{2i/d}$ shrinks more because the exponent is larger.

This is exactly the property we want! By choosing the right $b'$, we can leave high frequencies untouched while scaling low frequencies by whatever factor we need.
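
A quick numerical check of this property, as a sketch: it recomputes the frequencies with a larger base and prints how much each pair slows down. The larger base $b' = 10b$ is an arbitrary illustrative value, not the NTK-aware choice derived below.

```python
import numpy as np


def rope_frequencies(d, base):
    """theta_i = 1 / base^(2i/d) for dimension pairs i = 0 .. d/2 - 1."""
    i = np.arange(d // 2)
    return 1.0 / base ** (2 * i / d)


d = 64
theta = rope_frequencies(d, 10000.0)
theta_larger_base = rope_frequencies(d, 100000.0)   # illustrative b' = 10 * b, not the NTK value

ratio = theta / theta_larger_base                   # slowdown factor theta_i / theta_i'
for i in (0, 8, 16, 31):
    print(f"i={i:2d}  slowdown = {ratio[i]:.3f}")
```

Pair 0 is unaffected, and the slowdown grows steadily with $i$, reaching about $9.3\times$ for the last pair.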

Out[4]:
Visualization
2D scatter plot showing token positions rotated by high-frequency dimension under different scaling methods.
High-frequency dimension (i=0): Original RoPE rotates by 1 radian per position, creating widely spaced points. Position Interpolation compresses to 0.125 rad/pos, making adjacent positions nearly overlap. NTK-aware preserves the original spacing.
2D scatter plot showing token positions rotated by low-frequency dimension under different scaling methods.
Low-frequency dimension (i=31): All methods produce similar results here. NTK-aware and Position Interpolation both apply full compression to low-frequency dimensions.

This geometric view makes the difference concrete. In the high-frequency dimension (left), original RoPE spreads positions around the circle, allowing the model to distinguish them easily. Position Interpolation bunches them together near the starting point, reducing distinguishability. NTK-aware scaling preserves the original spread. In the low-frequency dimension (right), all methods behave similarly since both Position Interpolation and NTK-aware apply full compression to these slow-rotating dimensions.

Deriving the Formula

Let's work backward from our requirements to find the exact formula for $b'$. We want a scaling function where:

  • $\text{scale}_0 = 1$ (no scaling at highest frequency)
  • $\text{scale}_{d/2-1} = s$ (full scaling at lowest frequency)

Let $b' = b \cdot \alpha$ for some multiplier $\alpha$ we need to determine. Substituting into the frequency formula:

$$\theta_i' = \frac{1}{(b \cdot \alpha)^{2i/d}}$$

Using the exponent rule $(xy)^n = x^n \cdot y^n$, we can separate this:

$$\theta_i' = \frac{1}{b^{2i/d}} \cdot \frac{1}{\alpha^{2i/d}} = \theta_i \cdot \alpha^{-2i/d}$$

The effective scaling factor for dimension $i$ becomes:

$$\text{scale}_i = \frac{\theta_i}{\theta_i'} = \alpha^{2i/d}$$

Now we apply our constraints. At $i = 0$, we need $\text{scale}_0 = 1$:

$$\text{scale}_0 = \alpha^{2 \cdot 0/d} = \alpha^0 = 1 \quad \text{(as required)}$$

This is automatically satisfied for any $\alpha > 0$, since any positive number raised to the power zero equals one. Our first requirement is met regardless of what $\alpha$ we choose.

At $i = d/2 - 1$ (the last dimension pair), we need $\text{scale}_{d/2-1} = s$:

$$\text{scale}_{d/2-1} = \alpha^{2(d/2 - 1)/d} = s$$

Since $2(d/2 - 1) = d - 2$, the exponent simplifies to $(d-2)/d$:

$$\alpha^{(d - 2)/d} = s$$

To solve for $\alpha$, we raise both sides to the power $d/(d-2)$:

$$\alpha = s^{d/(d-2)}$$

This gives us the complete NTK-aware formula:

$$b' = b \cdot s^{\frac{d}{d - 2}}$$

where:

  • $b'$: the modified base for NTK-aware scaling
  • $b$: the original base (typically 10000)
  • $s$: the context extension factor (e.g., $s = 8$ to extend from 4k to 32k tokens)
  • $d$: the dimension being rotated (the per-head dimension)
  • $d/(d-2)$: an exponent slightly greater than 1 (for $d = 64$, this equals $64/62 \approx 1.032$)
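
The derivation can be verified numerically. The sketch below computes $b' = b \cdot s^{d/(d-2)}$ for a few head dimensions and extension factors (values chosen for illustration) and checks that the resulting per-pair scale $\theta_i/\theta_i'$ is 1 at $i = 0$, exactly $s$ at $i = d/2 - 1$, and equal to $s^{2i/(d-2)}$ in between.

```python
import numpy as np


def rope_frequencies(d, base):
    """theta_i = 1 / base^(2i/d) for dimension pairs i = 0 .. d/2 - 1."""
    i = np.arange(d // 2)
    return 1.0 / base ** (2 * i / d)


def ntk_base(base, s, d):
    """NTK-aware base from the derivation: b' = b * s^(d / (d - 2))."""
    return base * s ** (d / (d - 2))


for d in (64, 128):
    for s in (2.0, 8.0):
        theta = rope_frequencies(d, 10000.0)
        theta_ntk = rope_frequencies(d, ntk_base(10000.0, s, d))
        scale = theta / theta_ntk                    # per-pair scaling factor theta_i / theta_i'
        closed_form = s ** (2 * np.arange(d // 2) / (d - 2))
        assert np.isclose(scale[0], 1.0)             # highest frequency untouched
        assert np.isclose(scale[-1], s)              # lowest frequency scaled by exactly s
        assert np.allclose(scale, closed_form)       # s^(2i/(d-2)) in between
        print(f"d={d:3d}  s={s:.0f}  b'={ntk_base(10000.0, s, d):,.0f}")
```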

Understanding the Resulting Frequencies

Substituting the new base into the frequency formula reveals how each dimension is affected:

$$\theta_i' = \frac{1}{(b')^{2i/d}} = \frac{1}{(b \cdot s^{d/(d-2)})^{2i/d}} = \theta_i \cdot s^{-2i/(d-2)}$$

The effective scaling factor for each dimension is:

$$\text{scale}_i = s^{\frac{2i}{d - 2}}$$

Let's verify this matches our design goals:

  • When $i = 0$ (highest frequency): $\text{scale}_0 = s^0 = 1$. The highest frequencies are not scaled at all. ✓
  • When $i = d/2 - 1$ (lowest frequency): $\text{scale}_{d/2-1} = s^{(d-2)/(d-2)} = s$. The lowest frequencies are scaled by the full factor. ✓
  • Intermediate dimensions: Scaling increases smoothly from 1 to $s$ as $i$ increases. ✓

This is precisely what we wanted: high frequencies preserved, low frequencies scaled, smooth transition between.

A Worked Example

Let's make this concrete with typical values: $d = 64$ dimensions, $s = 8$ extension factor (4k to 32k tokens).

First, compute the NTK exponent:

$$\frac{d}{d-2} = \frac{64}{62} \approx 1.032$$

The modified base becomes:

$$b' = 10000 \cdot 8^{64/62} \approx 10000 \cdot 8.555 \approx 85{,}550$$

Now we can compute scaling factors for specific dimensions:

NTK-aware scaling factors for $d = 64$, $s = 8$. High-frequency dimensions (small $i$) are preserved while low-frequency dimensions receive full scaling.

| Dimension $i$ | Exponent $2i/(d-2)$ | $\text{scale}_i = 8^{\text{exp}}$ | Interpretation |
|---|---|---|---|
| 0 | 0 | 1.00 | No scaling (preserved) |
| 8 | 0.258 | 1.71 | Slight scaling |
| 16 | 0.516 | 2.92 | Moderate scaling |
| 24 | 0.774 | 5.00 | Significant scaling |
| 31 | 1.000 | 8.00 | Full scaling |

The gradient is smooth and continuous. High-frequency pairs (small ii) stay close to their original values, while low-frequency pairs (large ii) are compressed to fit the extended context.
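
The table rows follow directly from $\text{scale}_i = s^{2i/(d-2)}$. A minimal sketch that recomputes them for $d = 64$ and $s = 8$:

```python
import numpy as np

d, s = 64, 8.0
rows = [0, 8, 16, 24, 31]                       # the dimension pairs shown in the table

scales = {}
for i in rows:
    exponent = 2 * i / (d - 2)                  # 2i / (d - 2)
    scales[i] = s ** exponent                   # scale_i = s^(2i/(d-2))
    print(f"i={i:2d}  exponent={exponent:.3f}  scale={scales[i]:.2f}")

# The full curve is monotone: each pair is scaled more than the one before it.
all_scales = s ** (2 * np.arange(d // 2) / (d - 2))
assert np.all(np.diff(all_scales) > 0)
```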

Out[5]:
Visualization
Line plot showing uniform scaling for Position Interpolation and progressive scaling for NTK-aware method across dimension indices.
Scaling factors across dimension pairs for Position Interpolation vs NTK-aware scaling. Position Interpolation applies uniform scaling (flat line), while NTK-aware scaling preserves high frequencies (left) and only compresses low frequencies (right).

The visualization confirms our mathematical analysis. The NTK-aware curve (squares) starts at 1 for the highest-frequency pairs and smoothly increases to the full scale factor of 8 for the lowest-frequency pairs. Meanwhile, Position Interpolation (circles) maintains a flat line at 8, treating all dimensions identically.

Implementation

With the formula derived and verified, let's translate it into code. We'll build the implementation in layers, starting with the core frequency computation and progressively adding the full RoPE transformation.

Computing Frequencies

The foundation is a function that computes RoPE frequencies for any base value. This same function serves both the original frequencies and NTK-aware frequencies, since the only difference is the base we pass in:

In[6]:
Code
import numpy as np


def compute_rope_frequencies(d, base=10000):
    """Compute RoPE frequencies 1 / base^(2i/d) for the d/2 pairs; works for any base."""
    dim_pairs = np.arange(d // 2)
    frequencies = 1.0 / (base ** (2 * dim_pairs / d))
    return frequencies

Position Interpolation simply divides all frequencies by the scale factor:

In[7]:
Code
def position_interpolation_frequencies(d, base=10000, scale=1.0):
    """Compute Position Interpolation frequencies (uniform scaling)."""
    original = compute_rope_frequencies(d, base)
    return original / scale

NTK-aware scaling modifies the base according to our derived formula, then computes frequencies using the modified base:

In[8]:
Code
def ntk_aware_frequencies(d, base=10000, scale=1.0):
    """Compute NTK-aware scaled frequencies (non-uniform scaling)."""
    # Modify the base according to NTK formula: b' = b * s^(d/(d-2))
    ntk_base = base * (scale ** (d / (d - 2)))
    return compute_rope_frequencies(d, ntk_base)

Comparing the Approaches

Let's compute frequencies for all three methods and compare them side by side:

In[9]:
Code
# Compare original, Position Interpolation, and NTK-aware frequencies
d = 64
base = 10000
scale = 8  # Extend context by 8x

freq_original = compute_rope_frequencies(d, base)
freq_pi = position_interpolation_frequencies(d, base, scale)
freq_ntk = ntk_aware_frequencies(d, base, scale)
Out[10]:
Console
Context extension factor: 8x
Embedding dimension: 64

Frequency comparison (first 5 dimension pairs):
 Dim     Original           PI          NTK   PI ratio  NTK ratio
--------------------------------------------------------------
   0     1.000000     0.125000     1.000000       8.00       1.00
   1     0.749894     0.093737     0.701242       8.00       1.07
   2     0.562341     0.070293     0.491741       8.00       1.14
   3     0.421697     0.052712     0.344829       8.00       1.22
   4     0.316228     0.039528     0.241809       8.00       1.31

Frequency comparison (last 5 dimension pairs):
 Dim     Original           PI          NTK   PI ratio  NTK ratio
--------------------------------------------------------------
  27     0.000422     0.000053     0.000069       8.00       6.12
  28     0.000316     0.000040     0.000048       8.00       6.54
  29     0.000237     0.000030     0.000034       8.00       7.00
  30     0.000178     0.000022     0.000024       8.00       7.48
  31     0.000133     0.000017     0.000017       8.00       8.00

The numbers confirm what we derived mathematically. Look at the "PI ratio" column: it's a constant 8.00 across all dimensions, reflecting Position Interpolation's uniform scaling. Now look at "NTK ratio": it starts near 1.00 for dimension 0 and gradually increases toward 8.00 for the highest dimensions. This is the dimension-dependent scaling in action.

Visualizing the Frequency Spectrum

A log-scale plot reveals the full picture across all 32 dimension pairs:

Out[11]:
Visualization
Line plot comparing angular separation between adjacent tokens for original, Position Interpolation, and NTK-aware scaling across dimension pairs.
Angular separation between adjacent tokens across all dimension pairs. NTK-aware scaling preserves the original separation for high-frequency dimensions (left side), while Position Interpolation compresses all dimensions uniformly.

On the log scale, observe how NTK-aware frequencies (triangles) track the original (circles) closely for low dimension indices, where high frequencies live. As we move right toward higher dimension indices (lower frequencies), the NTK curve gradually diverges to match Position Interpolation (squares). This is exactly the "preserve high frequencies, scale low frequencies" behavior we designed.

Rotation Angle Heatmaps

To visualize how different scaling methods affect the rotation patterns, let's create heatmaps showing the rotation angles across positions and dimension pairs:

Out[12]:
Visualization
Heatmap of original rotation angles across 20 positions and 32 dimension pairs.
Original RoPE rotation angles. High-frequency dimensions (bottom) show rapid variation, while low-frequency dimensions (top) change slowly.
Heatmap of Position Interpolation rotation angles showing uniform compression.
Position Interpolation uniformly compresses all rotation angles, reducing variation across all dimensions equally.
Heatmap of NTK-aware rotation angles showing preserved high frequencies.
NTK-aware scaling preserves high-frequency patterns (bottom) while compressing only low-frequency dimensions (top).

The heatmaps reveal the key difference between methods. In the original (left), the bottom rows (high-frequency dimensions) show rapid color cycling as position increases, while top rows (low-frequency dimensions) change slowly. Position Interpolation (center) uniformly slows down all cycling. NTK-aware scaling (right) preserves the rapid cycling in the bottom rows while slowing only the top rows. This visual confirms that NTK-aware scaling selectively modifies frequencies based on dimension.

Applying RoPE with NTK-aware Scaling

Now that we can compute the frequencies, let's implement the complete RoPE transformation and measure its effect on token distinguishability.

The RoPE Transformation

RoPE works by rotating each pair of embedding dimensions by a position-dependent angle. The rotation angle for position $m$ and dimension pair $i$ is simply $m \cdot \theta_i$, where $\theta_i$ is the frequency we computed above:

In[13]:
Code
def apply_rope(x, frequencies):
    """Apply RoPE to a sequence of vectors.

    Args:
        x: Input vectors of shape (seq_len, d)
        frequencies: Base frequencies for each dimension pair, shape (d/2,)

    Returns:
        Rotated vectors of same shape as x
    """
    seq_len, d = x.shape
    positions = np.arange(seq_len)

    # Compute rotation angles: (seq_len, d/2)
    # Each position m gets angle m * theta_i for each dimension pair i
    angles = np.outer(positions, frequencies)

    # Split input into pairs for 2D rotation
    x_pairs = x.reshape(seq_len, -1, 2)  # (seq_len, d/2, 2)

    # Precompute sin and cos for efficiency
    cos_angles = np.cos(angles)  # (seq_len, d/2)
    sin_angles = np.sin(angles)

    # Apply 2D rotation: [x, y] -> [x*cos - y*sin, x*sin + y*cos]
    x_rotated = np.zeros_like(x_pairs)
    x_rotated[:, :, 0] = (
        x_pairs[:, :, 0] * cos_angles - x_pairs[:, :, 1] * sin_angles
    )
    x_rotated[:, :, 1] = (
        x_pairs[:, :, 0] * sin_angles + x_pairs[:, :, 1] * cos_angles
    )

    return x_rotated.reshape(seq_len, d)
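
Whatever base is used, RoPE keeps its defining property: the dot product between a query rotated to position $m$ and a key rotated to position $n$ depends only on the offset $m - n$. The sketch below checks this for the original and the NTK-aware base ($s = 8$). It uses a hypothetical single-vector helper, `rotate`, written for this check with the same pairing convention as `apply_rope`; the positions and random vectors are arbitrary choices.

```python
import numpy as np


def rotate(x, pos, frequencies):
    """Rotate consecutive pairs (x[2i], x[2i+1]) of a 1-D vector by pos * theta_i."""
    pairs = x.reshape(-1, 2)
    cos, sin = np.cos(pos * frequencies), np.sin(pos * frequencies)
    out = np.empty_like(pairs)
    out[:, 0] = pairs[:, 0] * cos - pairs[:, 1] * sin
    out[:, 1] = pairs[:, 0] * sin + pairs[:, 1] * cos
    return out.reshape(-1)


d, base, s = 64, 10000.0, 8.0
rng = np.random.default_rng(0)
q, k = rng.standard_normal(d), rng.standard_normal(d)

results = {}
for name, b in [("original", base), ("ntk", base * s ** (d / (d - 2)))]:
    freqs = 1.0 / b ** (2 * np.arange(d // 2) / d)
    near = rotate(q, 10, freqs) @ rotate(k, 7, freqs)        # offset m - n = 3
    far = rotate(q, 5010, freqs) @ rotate(k, 5007, freqs)    # same offset, shifted by 5000
    results[name] = (near, far)
    print(f"{name:8s} near={near:.6f} far={far:.6f}")
```

Both bases give matching scores for the near and far pairs, so NTK-aware scaling changes how fast each pair rotates without breaking relative-position encoding.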

Measuring Token Distinguishability

The ultimate test of our scaling method is whether adjacent tokens remain distinguishable after rotation. If tokens at positions $m$ and $m+1$ become too similar, the model loses its ability to attend precisely to nearby positions.

We'll measure this using cosine similarity between adjacent token embeddings. Lower similarity means more distinguishable tokens:

In[14]:
Code
# Create sample embeddings for testing
np.random.seed(42)
seq_len = 10
d = 64
embeddings = np.random.randn(seq_len, d) * 0.1

# Apply RoPE with different scaling methods
rotated_original = apply_rope(embeddings, freq_original)
rotated_pi = apply_rope(embeddings, freq_pi)
rotated_ntk = apply_rope(embeddings, freq_ntk)
In[15]:
Code
# Compute cosine similarity between adjacent tokens
def cosine_similarity(a, b):
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))


# Calculate similarities for each method
similarities = {"original": [], "pi": [], "ntk": []}

for i in range(seq_len - 1):
    similarities["original"].append(
        cosine_similarity(rotated_original[i], rotated_original[i + 1])
    )
    similarities["pi"].append(
        cosine_similarity(rotated_pi[i], rotated_pi[i + 1])
    )
    similarities["ntk"].append(
        cosine_similarity(rotated_ntk[i], rotated_ntk[i + 1])
    )
Out[16]:
Console
Cosine similarity between adjacent token embeddings:
   Positions   Original         PI        NTK
--------------------------------------------
       0 & 1    -0.0311     0.0152    -0.0237
       1 & 2    -0.2305    -0.2127    -0.2279
       2 & 3    -0.0370    -0.0254    -0.0345
       3 & 4     0.2245     0.2120     0.2135
       4 & 5    -0.0194     0.0068    -0.0210
       5 & 6    -0.0112     0.0196    -0.0044
       6 & 7    -0.1081    -0.0881    -0.1076
       7 & 8     0.1659     0.1631     0.1591
       8 & 9     0.0982     0.1182     0.1049
--------------------------------------------
     Average     0.0057     0.0232     0.0065
Out[17]:
Visualization
Bar chart comparing average cosine similarity between adjacent tokens for Original, Position Interpolation, and NTK-aware scaling methods.
Average cosine similarity between adjacent tokens for each scaling method. Lower similarity indicates more distinguishable positions. NTK-aware scaling preserves token distinguishability closer to the original, while Position Interpolation compresses the representation.

The averages point in the expected direction: Position Interpolation raises the mean similarity between adjacent tokens (0.0232 versus 0.0057 for the original), making them harder to distinguish, while NTK-aware scaling stays close to the original (0.0065). The differences are small in this toy setup of ten random vectors, so they illustrate the trend rather than measure it. This preservation of token distinguishability is the practical benefit of frequency-dependent scaling.
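
The comparison above draws a different random vector for every position, so part of each similarity reflects content rather than position. A variant that isolates the positional effect, sketched below as an alternative rather than a replacement, places the same vector $x$ at every position. Because rotation preserves norms and each pair contributes $\lVert x_{\text{pair}}\rVert^2 \cos\theta_i$ to the dot product, the cosine similarity between positions $m$ and $m + 1$ is a weighted average of $\cos\theta_i$ and does not depend on $m$. The all-ones vector is an assumption chosen to weight every pair equally.

```python
import numpy as np


def adjacent_similarity(x, frequencies):
    """Cosine similarity between one vector x rotated to positions m and m + 1 (any m)."""
    pair_energy = (x.reshape(-1, 2) ** 2).sum(axis=1)        # |x_pair|^2 for each pair
    return float((pair_energy * np.cos(frequencies)).sum() / pair_energy.sum())


d, base, s = 64, 10000.0, 8.0
i = np.arange(d // 2)
freq_original = 1.0 / base ** (2 * i / d)
freq_pi = freq_original / s
freq_ntk = 1.0 / (base * s ** (d / (d - 2))) ** (2 * i / d)

x = np.ones(d)                                                # equal weight on every pair
sims = {
    "original": adjacent_similarity(x, freq_original),
    "pi": adjacent_similarity(x, freq_pi),
    "ntk": adjacent_similarity(x, freq_ntk),
}
for name, value in sims.items():
    print(f"{name:8s} {value:.4f}")
```

Every original frequency is at most 1 radian, where cosine is decreasing, and each NTK-aware frequency lies between the original and the PI frequency, so the ordering here is guaranteed: the original is lowest, Position Interpolation is highest, and NTK-aware falls in between.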

Dynamic NTK Scaling

Static NTK-aware scaling uses a fixed extension factor $s$. But what if the actual sequence length varies? A model configured for 32k tokens shouldn't apply aggressive scaling when processing only 2k tokens. Ideally, the scaling should adapt to the current context: no modification for short sequences, progressive scaling as sequences grow longer.

Dynamic NTK scaling adapts the base in real-time based on the current sequence length:

$$b'(L) = b \cdot \left(\frac{L}{L_{\text{train}}}\right)^{\frac{d}{d-2}}$$

where:

  • $b'(L)$: the dynamically computed base, now a function of sequence length $L$
  • $b$: the original base (typically 10000)
  • $L$: the current sequence length being processed
  • $L_{\text{train}}$: the context length the model was trained with
  • $L/L_{\text{train}}$: the effective extension factor, computed on-the-fly
  • $d/(d-2)$: the same exponent as in static NTK-aware scaling

The formula is identical to static NTK scaling, except that $s = L/L_{\text{train}}$ is computed dynamically rather than fixed in advance.

When $L \leq L_{\text{train}}$, the ratio $L/L_{\text{train}} \leq 1$, and we clamp it to 1 (no scaling). When $L$ exceeds the training length, scaling kicks in proportionally. A sequence of 8k tokens (2x the training length of 4k) gets $s = 2$; a sequence of 32k tokens (8x) gets $s = 8$.

In[18]:
Code
def dynamic_ntk_frequencies(d, base, train_length, current_length):
    """Compute dynamically scaled NTK-aware frequencies.

    Args:
        d: Embedding dimension
        base: Original RoPE base (typically 10000)
        train_length: Training context length
        current_length: Current sequence length

    Returns:
        Frequencies for each dimension pair
    """
    if current_length <= train_length:
        # No scaling needed
        return compute_rope_frequencies(d, base)

    # Dynamic scaling factor
    scale = current_length / train_length
    ntk_base = base * (scale ** (d / (d - 2)))
    return compute_rope_frequencies(d, ntk_base)
In[19]:
Code
# Demonstrate dynamic scaling at different sequence lengths
d = 64
base = 10000
train_length = 4096
test_lengths = [2048, 4096, 8192, 16384, 32768]

freq_by_length = {
    L: dynamic_ntk_frequencies(d, base, train_length, L) for L in test_lengths
}
Out[20]:
Console
Dynamic NTK scaling: effective base at different sequence lengths
Training length: 4096

  Seq Length    Scale   Effective Base
--------------------------------------
        2048     1.00            10000
        4096     1.00            10000
        8192     2.00            20452
       16384     4.00            41829
       32768     8.00            85550

The effective base grows smoothly with sequence length. At 32k tokens (8x the training length), it reaches approximately 85.6k, the same base the static formula gives for $s = 8$, since dynamic scaling reduces to static scaling once $L/L_{\text{train}}$ is fixed.

Out[21]:
Visualization
Line plot showing how RoPE frequencies change with sequence length under dynamic NTK scaling, with different curves for different sequence lengths.
Dynamic NTK-aware frequencies adapt based on sequence length. Short sequences (at or below training length) use original frequencies, while longer sequences progressively modify frequencies to extend context.

Attention Score Analysis

To understand the practical impact, let's examine how attention scores behave under different scaling methods. We'll create synthetic queries and keys at various distances and measure the attention patterns:

In[22]:
Code
def compute_attention_scores(q, k):
    """Compute scaled dot-product attention scores."""
    d = q.shape[-1]
    scores = q @ k.T / np.sqrt(d)
    return scores


def create_position_aware_vectors(positions, d, frequencies):
    """Create vectors with RoPE applied at specified positions."""
    n = len(positions)
    # Use consistent base vectors
    np.random.seed(42)
    base_vectors = np.random.randn(n, d) * 0.1

    # Rotate each dimension pair by angle pos * theta_i (done pairwise, no explicit matrices)
    rotated = np.zeros_like(base_vectors)
    for idx, pos in enumerate(positions):
        angles = pos * frequencies
        cos_a = np.cos(angles)
        sin_a = np.sin(angles)

        for i in range(d // 2):
            x, y = base_vectors[idx, 2 * i], base_vectors[idx, 2 * i + 1]
            rotated[idx, 2 * i] = x * cos_a[i] - y * sin_a[i]
            rotated[idx, 2 * i + 1] = x * sin_a[i] + y * cos_a[i]

    return rotated
In[23]:
Code
# Examine attention between a query and keys at varying distances
d = 64
base = 10000
train_length = 4096
extended_length = 32768
scale = extended_length / train_length

# Different scaling approaches
freq_orig = compute_rope_frequencies(d, base)
freq_pi = position_interpolation_frequencies(d, base, scale)
freq_ntk = ntk_aware_frequencies(d, base, scale)

# Query at position 1000, keys at various distances
query_pos = 1000
key_distances = [1, 2, 5, 10, 50, 100, 500, 1000]
key_positions = [query_pos + dist for dist in key_distances]
Out[24]:
Visualization
Line plot showing attention score decay with distance for original, Position Interpolation, and NTK-aware scaling methods.
Relative attention scores as a function of distance from the query token. NTK-aware scaling maintains a decay pattern closer to the original, preserving the model's learned attention behavior for nearby tokens.

The attention decay pattern reveals a subtle but important distinction. While Position Interpolation flattens the attention curve (making nearby and distant tokens more similar in score), NTK-aware scaling preserves more of the original decay structure, especially at short distances.
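
The cell that sets up the query and key positions leaves the score computation implicit. One way to compute it is sketched below. Unlike `create_position_aware_vectors`, it assumes the query and key share the same content vector (all ones), so differences between the curves come from position alone; `rotate` is a hypothetical single-vector helper following the pairing convention of `apply_rope`.

```python
import numpy as np


def rotate(x, pos, frequencies):
    """Rotate consecutive pairs of a 1-D vector by pos * theta_i (pairing as in apply_rope)."""
    pairs = x.reshape(-1, 2)
    cos, sin = np.cos(pos * frequencies), np.sin(pos * frequencies)
    rotated = np.stack([pairs[:, 0] * cos - pairs[:, 1] * sin,
                        pairs[:, 0] * sin + pairs[:, 1] * cos], axis=1)
    return rotated.reshape(-1)


d, base, s = 64, 10000.0, 8.0
i = np.arange(d // 2)
theta = 1.0 / base ** (2 * i / d)
methods = {
    "original": theta,
    "pi": theta / s,
    "ntk": 1.0 / (base * s ** (d / (d - 2))) ** (2 * i / d),
}

x = np.ones(d)                       # assumed shared content for query and keys
query_pos = 1000
key_distances = [1, 2, 5, 10, 50, 100, 500, 1000]

scores = {}
for name, freqs in methods.items():
    q = rotate(x, query_pos, freqs)
    scores[name] = [float(q @ rotate(x, query_pos + dist, freqs)) / np.sqrt(d)
                    for dist in key_distances]
    print(f"{name:8s}", np.round(scores[name], 3))
```

With identical content, the largest possible score is $\lVert x\rVert^2/\sqrt{d} = 8$ at distance 0; how quickly each method's scores move away from that value as the distance grows is what the decay plot compares.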

Comparing NTK to Position Interpolation

Let's directly compare how well each method preserves the relative position encoding properties:

Out[25]:
Visualization
Line plot of dot product values vs relative position for original, Position Interpolation, and NTK-aware methods.
Dot product between rotated vectors as a function of relative position. The original RoPE curve shows strong sensitivity to position differences.
Line plot showing deviation from original dot product values for both scaling methods.
Deviation from original behavior. NTK-aware scaling (triangles) stays closer to zero, especially at small relative positions, indicating better preservation of learned patterns.

The deviation plot (right) makes the advantage clear: NTK-aware scaling maintains smaller deviations from the original behavior, particularly for small relative positions where local attention patterns matter most.
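
Part of this advantage can be read off the frequencies themselves. Since $1 \le \text{scale}_i \le s$, each NTK-aware frequency lies between the original and the Position Interpolation frequency, so its per-position error $\theta_i - \theta_i'$ is never larger than Position Interpolation's; at relative distance $\Delta$, the accumulated (unwrapped) angle error is $\Delta$ times this. The sketch below checks the frequency-level claim for $d = 64$ and $s = 8$; it is a proxy for, not a reproduction of, the dot-product deviation plotted above.

```python
import numpy as np

d, base, s = 64, 10000.0, 8.0
i = np.arange(d // 2)
theta = 1.0 / base ** (2 * i / d)
theta_pi = theta / s                           # uniform scaling
theta_ntk = theta / s ** (2 * i / (d - 2))     # NTK-aware closed form: theta_i / scale_i

err_pi = theta - theta_pi                      # per-position angle error, >= 0
err_ntk = theta - theta_ntk

assert np.all(err_ntk <= err_pi + 1e-15)       # never worse than PI, pair by pair
for idx in (0, 4, 16, 31):
    print(f"i={idx:2d}  PI error={err_pi[idx]:.5f}  NTK error={err_ntk[idx]:.5f}")
```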

Interpolation Factor Analysis

An alternative perspective on NTK-aware scaling is to analyze the interpolation factor for each frequency. We can express NTK-aware scaling as a blend between "no interpolation" and "full interpolation":

In[26]:
Code
def compute_interpolation_factors(d, scale):
    """Compute effective interpolation factor for each dimension pair.

    Returns values between 0 (no interpolation, high freq preserved)
    and 1 (full interpolation).
    """
    dim_pairs = np.arange(d // 2)

    # NTK scaling factor per pair: scale^(2i / (d-2)), from 1 (i = 0) to scale (last pair)
    ntk_factors = scale ** (2 * dim_pairs / (d - 2))

    if scale == 1:
        # No context extension: every pair is left untouched
        return np.zeros_like(ntk_factors)

    # Map factor 1 -> 0 (no interpolation) and factor `scale` -> 1 (full interpolation)
    interpolation = (ntk_factors - 1) / (scale - 1)

    return interpolation
Out[27]:
Visualization
Line plot showing interpolation factor from 0 to 1 across dimension indices, rising slowly at first and steeply near the end (a convex, exponential curve).
Interpolation factor across dimension pairs under NTK-aware scaling. Low-frequency dimensions (right) receive full interpolation, while high-frequency dimensions (left) are preserved with minimal modification.

This interpolation perspective provides an intuitive understanding: NTK-aware scaling smoothly transitions from "preserve high frequencies" (interpolation factor near 0) to "fully scale low frequencies" (interpolation factor near 1).
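
Because $\text{scale}_i$ grows exponentially in $i$, the interpolation factor is convex: it stays low for most pairs and climbs steeply only near the end. The sketch below, for $d = 64$ and $s = 8$, compares it with a straight-line ramp between the same endpoints.

```python
import numpy as np

d, s = 64, 8.0
i = np.arange(d // 2)
interp = (s ** (2 * i / (d - 2)) - 1) / (s - 1)   # 0 = untouched pair, 1 = full PI-style scaling
linear = i / (d // 2 - 1)                          # straight-line ramp between the same endpoints

print(f"midpoint i=16: NTK factor {interp[16]:.3f}, linear ramp {linear[16]:.3f}")
print("pairs with factor below 0.5:", int((interp < 0.5).sum()), "of", d // 2)
```

At the midpoint the factor is only about 0.27, versus about 0.52 for a linear ramp, and 23 of the 32 pairs sit below 0.5.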

Limitations and Practical Considerations

NTK-aware scaling often outperforms Position Interpolation, especially when applied without fine-tuning, but it comes with its own trade-offs and limitations.

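The core tension in context extension is that you cannot perfectly preserve all the properties of the original position encoding while also extending the context window. NTK-aware scaling preserves high frequencies by compressing every pair except the last by less than the full factor $s$. The cost is that some intermediate pairs, which never completed a full rotation within the training length, now reach rotation angles the model never saw during training, so those dimensions partly extrapolate rather than interpolate (with $d = 64$ and $s = 8$, pair $i = 24$ reaches about 6.6 radians at position 32,768, versus about 4.1 radians at the end of a 4,096-token training context). For tasks that rely heavily on precise long-range position information, this trade-off may not be ideal. Document retrieval, where finding a specific passage requires accurate global positioning, might suffer compared to summarization tasks where local coherence matters more.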
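
One concrete way to probe this trade-off, sketched below under the chapter's running assumptions ($d = 64$, $b = 10000$, a 4,096-token training length, $s = 8$), is to ask which dimension pairs reach rotation angles at the extended length that they never reached in training. Only pairs that never completed a full rotation during training matter here, since the others have already seen every angle modulo $2\pi$.

```python
import numpy as np

d, base, s = 64, 10000.0, 8.0
train_len = 4096
target_len = int(s * train_len)
i = np.arange(d // 2)

theta = 1.0 / base ** (2 * i / d)                 # original frequencies
theta_ntk = theta / s ** (2 * i / (d - 2))        # NTK-aware frequencies (closed form)
theta_pi = theta / s                              # Position Interpolation

max_train = train_len * theta                     # largest angle each pair saw in training
partial = max_train < 2 * np.pi                   # pairs that never completed a full rotation

tol = 1 + 1e-9                                    # ignore floating-point ties
ntk_beyond = partial & (target_len * theta_ntk > max_train * tol)
pi_beyond = partial & (target_len * theta_pi > max_train * tol)

print("never completed a rotation in training:", i[partial].tolist())
print("NTK-aware exceeds the trained range:   ", i[ntk_beyond].tolist())
print("PI exceeds the trained range:          ", i[pi_beyond].tolist())
```

Position Interpolation never leaves the trained range by construction, while NTK-aware scaling pushes the slow, partially rotated pairs past it, except the very last pair, which receives the full factor $s$.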

Additionally, NTK-aware scaling, like Position Interpolation, typically requires fine-tuning to achieve optimal performance. While models may work "out of the box" with NTK scaling applied at inference time, the attention patterns learned during training assumed different frequency relationships. Fine-tuning on longer sequences helps the model adapt to the new frequency landscape. The amount of fine-tuning needed is generally comparable to Position Interpolation: a few hundred to a few thousand steps on representative long-context data often suffices.

The choice between static and dynamic NTK scaling depends on deployment constraints. Static scaling is simpler to implement and has deterministic behavior, but requires knowing the maximum context length in advance. Dynamic scaling adapts gracefully to varying sequence lengths but adds computational overhead, as frequencies must be recomputed based on current sequence length. For most applications where sequence lengths are predictable, static scaling is sufficient.
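A minimal sketch of the difference between the two variants, assuming the chapter's formulas: the static base is fixed once for a chosen $s$, while the dynamic base is recomputed from the current sequence length $L$ and only departs from $b$ once $L$ exceeds the training length. The function names and model settings are hypothetical, and library implementations of dynamic NTK may use a slightly different expression.

```python
def rope_inv_freq(base: float, d: int) -> list[float]:
    # RoPE frequencies base**(-2i/d) for pairs i = 0 .. d/2 - 1
    return [base ** (-2 * i / d) for i in range(d // 2)]


def static_ntk_base(base: float, d: int, scale: float) -> float:
    # Computed once for the maximum target length: b' = b * s**(d/(d-2))
    return base * scale ** (d / (d - 2))


def dynamic_ntk_base(base: float, d: int, seq_len: int, train_length: int) -> float:
    # Recomputed per sequence: unchanged up to the training length,
    # then b'(L) = b * (L / L_train)**(d/(d-2))
    if seq_len <= train_length:
        return base
    return base * (seq_len / train_length) ** (d / (d - 2))


base, d, train_length = 10000.0, 128, 4096  # assumed model settings
static_base = static_ntk_base(base, d, scale=8.0)
for seq_len in (2048, 4096, 8192, 32768):
    dyn_base = dynamic_ntk_base(base, d, seq_len, train_length)
    # The dynamic variant needs a fresh frequency table whenever its base changes
    inv_freq = rope_inv_freq(dyn_base, d)
    print(f"L={seq_len:>6}: dynamic base {dyn_base:,.0f}, static base {static_base:,.0f}, "
          f"lowest freq {inv_freq[-1]:.2e}")
```

Short inputs keep the original frequencies under the dynamic variant, which is its main appeal; at $L = 32768$ the dynamic base coincides with the static base for $s = 8$.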

Finally, NTK-aware scaling is specific to RoPE and doesn't transfer to other position encoding schemes. Models using absolute position embeddings, ALiBi, or other mechanisms require different extension strategies. This limits the generality of the approach, though the dominance of RoPE in modern architectures means NTK-aware scaling is broadly applicable.

Key Parameters

When implementing NTK-aware scaling, several parameters control the behavior of the context extension:

  • base (default: 10000): The original RoPE base constant. This value comes from the pretrained model and should match what was used during training. Common values are 10000 (LLaMA, LLaMA 2, Mistral 7B v0.1) or 500000 (some newer models with extended context, such as Llama 3).

  • scale ($s$): The context extension factor, computed as target length divided by training length. For example, extending from 4k to 32k tokens gives $s = 8$. Larger values enable longer contexts but increase the distortion from the original frequencies.

  • d (head dimension): The per-head dimension that RoPE rotates, used in computing the NTK exponent $d/(d-2)$. It is fixed by the model architecture and typically ranges from 64 to 128 in modern transformers.

  • train_length (for dynamic scaling): The context length the model was originally trained with. This serves as the threshold below which no scaling is applied. Using the correct value is critical for dynamic NTK to work properly.
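Plugging the 4k-to-32k example from this list into the formulas shows the magnitudes involved. The head dimension $d = 128$ and base $b = 10000$ are assumptions taken from the typical values above.

```latex
\begin{aligned}
s &= \frac{L_{\text{target}}}{L_{\text{train}}} = \frac{32768}{4096} = 8, \qquad \frac{d}{d-2} = \frac{128}{126} \approx 1.0159 \\
b' &= b \cdot s^{d/(d-2)} = 10000 \cdot 8^{1.0159} \approx 8.27 \times 10^{4} \\
\theta_i' &= (b')^{-2i/d} = \theta_i \cdot s^{-2i/(d-2)}, \qquad i = \tfrac{d}{2} - 1 = 63 \;\Rightarrow\; \theta_{63}' = \theta_{63} / 8
\end{aligned}
```

The new base grows slightly faster than $s$ itself because the exponent $d/(d-2)$ is a little above 1; that small excess is what makes the lowest-frequency pair land exactly on the full factor $s$.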

When choosing between static and dynamic scaling, consider your deployment scenario. Static scaling with a fixed $s$ is simpler and more predictable, suitable when you know the maximum sequence length in advance. Dynamic scaling adapts to varying inputs but adds computational overhead for recomputing frequencies.

Summary

NTK-aware scaling provides a principled approach to extending context length in RoPE-based models by recognizing that different frequency bands serve different purposes:

  • Key insight: High-frequency components distinguish nearby tokens; low-frequency components capture long-range structure. Uniform scaling damages local precision unnecessarily.

  • The NTK formula: Replace base $b$ with $b' = b \cdot s^{d/(d-2)}$, where $s$ is the context extension factor. This preserves high frequencies while scaling low frequencies.

  • Frequency-dependent scaling: The effective scaling factor increases from 1 (no change) for the highest frequencies to $s$ (full scaling) for the lowest frequencies.

  • Dynamic variant: Adapt the base at inference time based on the current sequence length $L$: $b'(L) = b \cdot (L/L_{\text{train}})^{d/(d-2)}$ for $L > L_{\text{train}}$, and $b'(L) = b$ otherwise.

  • Practical benefits: Better preservation of local attention patterns compared to Position Interpolation, especially for tasks requiring precise nearby-token relationships.

NTK-aware scaling improved context extension, but researchers continued seeking better solutions. In the next chapter, we'll explore YaRN (Yet another RoPE extension method), which builds on NTK-aware principles, refines how each frequency band is interpolated, and adds a temperature factor that rescales the attention logits.


Reference

BibTeX
@misc{ntkawarescalingextendingcontextlengthinllms, author = {Michael Brenndoerfer}, title = {NTK-aware Scaling: Extending Context Length in LLMs}, year = {2025}, url = {https://mbrenndoerfer.com/writing/ntk-aware-scaling-context-extension}, organization = {mbrenndoerfer.com}, note = {Accessed: 2025-12-19} }
APA
Michael Brenndoerfer (2025). NTK-aware Scaling: Extending Context Length in LLMs. Retrieved from https://mbrenndoerfer.com/writing/ntk-aware-scaling-context-extension
MLA
Michael Brenndoerfer. "NTK-aware Scaling: Extending Context Length in LLMs." 2025. Web. 12/19/2025. <https://mbrenndoerfer.com/writing/ntk-aware-scaling-context-extension>.
Chicago
Michael Brenndoerfer. "NTK-aware Scaling: Extending Context Length in LLMs." Accessed 12/19/2025. https://mbrenndoerfer.com/writing/ntk-aware-scaling-context-extension.
Harvard
Michael Brenndoerfer (2025) 'NTK-aware Scaling: Extending Context Length in LLMs'. Available at: https://mbrenndoerfer.com/writing/ntk-aware-scaling-context-extension (Accessed: 12/19/2025).
Simple
Michael Brenndoerfer (2025). NTK-aware Scaling: Extending Context Length in LLMs. https://mbrenndoerfer.com/writing/ntk-aware-scaling-context-extension
About the author: Michael Brenndoerfer

All opinions expressed here are my own and do not reflect the views of my employer.

Michael currently works as an Associate Director of Data Science at EQT Partners in Singapore, leading AI and data initiatives across private capital investments.

With over a decade of experience spanning private equity, management consulting, and software engineering, he specializes in building and scaling analytics capabilities from the ground up. He has published research in leading AI conferences and holds expertise in machine learning, natural language processing, and value creation through data.
