Recurrent Memory: Extending Transformer Context with Segment-Level State Caching

Michael Brenndoerfer · Updated July 6, 2025 · 50 min read

Learn how Transformer-XL uses segment-level recurrence to extend effective context length by caching hidden states, why relative position encodings are essential for cross-segment attention, and when recurrent memory approaches outperform standard transformers.

Recurrent Memory

Transformers process sequences in fixed-length segments. When a document exceeds the context window, the standard approach is to truncate or split it, processing each piece independently. But this independence comes at a cost: information from earlier segments vanishes entirely. A pronoun in segment 3 cannot resolve to its antecedent in segment 1 because segment 1 no longer exists in the model's view.

Transformer-XL introduced a solution that seems almost obvious in retrospect: what if we kept the hidden states from the previous segment and let the current segment attend to them? This segment-level recurrence creates a form of memory that extends effective context far beyond the training sequence length. The model processes sequences one segment at a time, but each segment can "remember" what came before through cached hidden states.

This chapter explores how Transformer-XL implements recurrent memory, why it requires relative positional encodings, and what limitations remain. Understanding this approach illuminates a key tension in long-context modeling: the tradeoff between computational efficiency and true bidirectional context.

The Segment Boundary Problem

Standard transformers suffer from what the Transformer-XL paper calls "context fragmentation." When you split a long document into fixed-length segments, each segment is processed without any knowledge of its neighbors. The model sees each chunk as an independent sequence.

Context Fragmentation

Context fragmentation occurs when a transformer processes a long sequence in fixed-length segments without information flow between them. Each segment starts fresh, losing all context from previous segments regardless of semantic continuity.

Consider processing a 4096-token document with a 512-token context window. The naive approach splits this into 8 segments, processing each independently. Token 513 (the first token of segment 2) cannot attend to token 512 (the last token of segment 1). They're as disconnected as tokens from completely different documents.

In[2]:
Code
# Simulate context fragmentation
document_length = 4096
segment_length = 512
num_segments = document_length // segment_length

# For each segment, the maximum dependency distance is limited
# to within-segment connections only
max_within_segment_distance = segment_length - 1
max_possible_distance = document_length - 1

# Calculate the fraction of potential dependencies that are visible
visible_dependencies_per_segment = segment_length * (segment_length - 1) // 2
total_possible_dependencies = document_length * (document_length - 1) // 2
visibility_fraction = (
    visible_dependencies_per_segment * num_segments
) / total_possible_dependencies
Out[3]:
Console
Document length: 4,096 tokens
Segment length: 512 tokens
Number of segments: 8

Within each segment:
  Maximum attention distance: 511 tokens
  Visible dependencies per segment: 130,816

Across entire document:
  Total possible dependencies: 8,386,560
  Fraction visible with fragmentation: 12.48%

Less than 13% of potential token-to-token dependencies are visible when processing with context fragmentation. Cross-segment dependencies, which may carry crucial information like coreference chains or long-range discourse structure, are completely invisible.

Out[4]:
Visualization
Block diagonal attention pattern showing 4 isolated segments with no cross-segment connections.
Context fragmentation in standard transformer processing. Each segment is processed independently, with no information flow across segment boundaries. The attention matrix shows isolated blocks along the diagonal, indicating that tokens can only attend within their own segment.

The attention pattern reveals the fundamental limitation: each segment is an island. No matter how important a reference in segment 1 might be for understanding segment 4, that information cannot flow through the attention mechanism.

Transformer-XL: Segment-Level Recurrence

Now that we understand the problem, let's explore how Transformer-XL solves it. The solution is remarkably simple: instead of discarding the previous segment entirely, cache its hidden states and make them available during attention computation for the current segment.

Think of it like this: when you read a new paragraph, you don't forget the previous one. You carry forward a mental summary of what came before. That summary isn't the original words themselves, but your processed understanding of them. Transformer-XL does exactly this, but with hidden states instead of mental summaries.

Segment-Level Recurrence

Segment-level recurrence is a technique where hidden states from the previous segment are cached and concatenated with the current segment's keys and values during attention computation. This allows information to flow across segment boundaries without recomputing attention over the entire history.

The Mechanism Step by Step

The recurrence operates at each layer independently. When processing layer $n$ of segment $\tau$, the model:

  1. Retrieves the cached hidden states $\mathbf{h}_{\tau-1}^{n-1}$ from processing the previous segment at layer $n-1$
  2. Concatenates these cached states with the current segment's hidden states
  3. Computes attention where queries come only from the current segment, but keys and values span both the cached and current states
  4. Caches the current segment's output hidden states for use when processing the next segment

This creates an asymmetric attention pattern: current tokens can "look back" at cached tokens, but we never recompute outputs for the cached tokens. They're frozen representations from the previous forward pass.
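
To make the data flow concrete before formalizing it, here is a minimal sketch of the caching loop. It is illustrative only: layer_forward is a hypothetical placeholder for a single transformer layer, and a full implementation follows later in this chapter.

import numpy as np


def layer_forward(hidden: np.ndarray, memory: np.ndarray | None) -> np.ndarray:
    # Hypothetical placeholder: a real layer would attend over [memory; hidden]
    return hidden


num_layers = 3
segment_length, d_model = 8, 16
segments = [np.random.randn(segment_length, d_model) for _ in range(4)]

# One cache slot per layer, holding the previous segment's layer n-1 states
memory = [None] * num_layers

for segment in segments:
    hidden = segment
    new_memory = []
    for n in range(num_layers):
        # Cache this layer's input (the previous layer's output) before updating
        # it; copying the array plays the role of StopGrad in plain NumPy
        new_memory.append(hidden.copy())
        hidden = layer_forward(hidden, memory[n])
    memory = new_memory  # becomes the read-only cache for the next segment

The key detail is that the cache for layer $n$ holds the previous segment's layer $n-1$ states, which is exactly the formulation developed next.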

The Mathematical Formulation

Let's formalize this mechanism. Suppose we're processing segment $\tau$, which contains $L$ tokens. The previous segment's hidden states, also of length $L$ (or some memory length $M$), have been cached. At layer $n$, we want to compute attention over an extended context that includes both segments.

The extended context for attention at layer $n$ is:

$$\tilde{\mathbf{h}}_\tau^{n-1} = \text{StopGrad}(\mathbf{h}_{\tau-1}^{n-1}) \circ \mathbf{h}_\tau^{n-1}$$

where:

  • $\tilde{\mathbf{h}}_\tau^{n-1}$: the extended hidden states combining the previous and current segments
  • $\mathbf{h}_{\tau-1}^{n-1}$: cached hidden states from the previous segment at layer $n-1$
  • $\mathbf{h}_\tau^{n-1}$: the current segment's hidden states at layer $n-1$
  • $\text{StopGrad}(\cdot)$: stops gradient flow so backpropagation does not pass through the cached states
  • $\circ$: concatenation along the sequence dimension

The attention computation then becomes:

$$\mathbf{q}_\tau^n = \mathbf{h}_\tau^{n-1} W_q^n, \qquad \mathbf{k}_\tau^n = \tilde{\mathbf{h}}_\tau^{n-1} W_k^n, \qquad \mathbf{v}_\tau^n = \tilde{\mathbf{h}}_\tau^{n-1} W_v^n$$

$$\mathbf{h}_\tau^n = \text{Attention}(\mathbf{q}_\tau^n, \mathbf{k}_\tau^n, \mathbf{v}_\tau^n)$$

where:

  • $\mathbf{q}_\tau^n \in \mathbb{R}^{L \times d}$: queries from the current segment only
  • $\mathbf{k}_\tau^n, \mathbf{v}_\tau^n \in \mathbb{R}^{(L + M) \times d}$: keys and values from the extended context (current + cached)
  • $W_q^n, W_k^n, W_v^n$: learnable projection matrices for layer $n$
  • $M$: the length of the cached memory (typically equal to the segment length $L$)

The critical detail is that queries come only from the current segment while keys and values include the cached previous segment. This asymmetry is intentional: we want current tokens to attend to past context, but we don't want to regenerate outputs for past tokens. The cached states are read-only: they provide context but don't receive updates.

Implementing the Mechanism

Let's translate this into code. The core operation is straightforward: concatenate cached and current hidden states, project to queries/keys/values, and compute attention with appropriate masking.

In[5]:
Code
import numpy as np


def transformer_xl_attention(
    current_hidden: np.ndarray,
    cached_hidden: np.ndarray,
    W_q: np.ndarray,
    W_k: np.ndarray,
    W_v: np.ndarray,
    d_k: int,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Compute Transformer-XL style attention with segment-level recurrence.

    Args:
        current_hidden: Current segment hidden states, shape (L, d)
        cached_hidden: Cached previous segment states, shape (M, d)
        W_q, W_k, W_v: Projection matrices, shape (d, d)
        d_k: Key dimension for scaling

    Returns:
        attention_output: Output for current segment, shape (L, d)
        attention_weights: Attention pattern, shape (L, L+M)
    """
    # Concatenate for extended context
    extended_hidden = np.concatenate([cached_hidden, current_hidden], axis=0)

    # Project to queries, keys, values
    # Queries: only from current segment
    queries = current_hidden @ W_q  # (L, d)
    # Keys and values: from extended context
    keys = extended_hidden @ W_k  # (L+M, d)
    values = extended_hidden @ W_v  # (L+M, d)

    # Compute attention scores
    scores = queries @ keys.T / np.sqrt(d_k)  # (L, L+M)

    # Apply causal masking
    # Current segment tokens can attend to all of cache + their own past
    L = current_hidden.shape[0]
    M = cached_hidden.shape[0]
    mask = np.ones_like(scores) * float("-inf")
    for i in range(L):
        # Can attend to all cached tokens + tokens up to and including position i
        mask[i, : M + i + 1] = 0
    scores = scores + mask

    # Softmax
    exp_scores = np.exp(scores - np.max(scores, axis=-1, keepdims=True))
    attention_weights = exp_scores / exp_scores.sum(axis=-1, keepdims=True)

    # Weighted sum of values
    output = attention_weights @ values  # (L, d)

    return output, attention_weights
In[6]:
Code
# Demonstrate the mechanism with a simple example
np.random.seed(42)

segment_length = 8
memory_length = 8
hidden_dim = 16
d_k = hidden_dim

# Simulate hidden states
current_segment = np.random.randn(segment_length, hidden_dim) * 0.1
cached_segment = np.random.randn(memory_length, hidden_dim) * 0.1

# Random projection matrices (in practice, these are learned)
W_q = np.random.randn(hidden_dim, hidden_dim) * 0.1
W_k = np.random.randn(hidden_dim, hidden_dim) * 0.1
W_v = np.random.randn(hidden_dim, hidden_dim) * 0.1

output, attention_weights = transformer_xl_attention(
    current_segment, cached_segment, W_q, W_k, W_v, d_k
)
Out[7]:
Console
Current segment shape: (8, 16)
Cached segment shape: (8, 16)
Attention weights shape: (8, 16)
Output shape: (8, 16)

Attention weight distribution for token 0 of current segment:
  Attention to cached tokens (0-7): 0.889
  Attention to current token (8): 0.111

The output shows that token 0 of the current segment allocates substantial attention to the cached memory. This is expected: with no preceding tokens in the current segment, the memory provides all available context. The attention weights reveal how information flows. The first token in the current segment can attend to all 8 cached tokens plus itself. Later tokens in the current segment can attend to even more context: all cached tokens plus all preceding tokens in the current segment.

Out[8]:
Visualization
Heatmap showing attention weights with a vertical segment boundary, where current tokens attend to both cached and current positions.
Transformer-XL attention pattern showing segment-level recurrence. Current segment tokens (positions 8-15) can attend to cached previous segment tokens (positions 0-7) as well as their own causal context. The dashed line marks the segment boundary.

The attention pattern shows the distinctive Transformer-XL signature: a triangular pattern in the current segment (causal self-attention) combined with a rectangular region on the left (attention to cached memory). Every token in the current segment can see the entire cached memory.

Effective Context Length

The recurrence mechanism creates a dependency chain across segments. Information from segment 1 flows to segment 2 through the cached hidden states. Segment 2's hidden states then carry forward information to segment 3. This chain means the effective context length grows beyond a single segment, bounded by how far information can propagate through the hidden states.

For an $N$-layer transformer with segment length $L$, the maximum effective context length is $O(N \times L)$. To understand why, consider how information propagates layer by layer:

  • Layer 1 of segment $\tau$ receives cached states from layer 0 of segment $\tau-1$, giving it access to 1 previous segment
  • Layer 2 of segment $\tau$ receives cached states from layer 1 of segment $\tau-1$, which already incorporated information from segment $\tau-2$ at its own layer 1
  • Layer $n$ can potentially access information from $n$ segments back through this chain of cached representations

The depth of the network acts as a multiplier on the effective context window.

In[9]:
Code
def compute_effective_context(num_layers: int, segment_length: int) -> dict:
    """
    Compute the effective context length for Transformer-XL.

    The effective context grows with depth because information
    propagates one segment further back at each layer.
    """
    # At layer n, information can potentially come from n segments back
    # This is because cached states at layer n contain information
    # that was already aggregated from previous segments at layer n-1

    # Direct attention span (what a single layer can see)
    direct_span = 2 * segment_length  # current + one cached segment

    # Maximum theoretical span considering all layers
    max_theoretical_span = num_layers * segment_length + segment_length

    # Information decay means effective span is somewhat less
    # Upper layers have access to more distant information but it's diluted

    return {
        "num_layers": num_layers,
        "segment_length": segment_length,
        "direct_attention_span": direct_span,
        "max_theoretical_span": max_theoretical_span,
        "context_multiplier": max_theoretical_span / segment_length,
    }


# Compare different configurations
configs = [
    (6, 128),  # Small model
    (12, 256),  # Medium model
    (24, 512),  # Large model
]

results = [compute_effective_context(n, l) for n, l in configs]
Out[10]:
Console
Effective Context Length in Transformer-XL

Configuration         Direct Span Max Theoretical   Multiplier
--------------------------------------------------------------
6L × 128L                     256             896          7.0x
12L × 256L                    512           3,328         13.0x
24L × 512L                  1,024          12,800         25.0x

A 24-layer model with 512-token segments can theoretically access information from up to 12,800 tokens back, even though it only directly attends to 1,024 tokens per layer. The depth of the network amplifies the effective context.

Out[11]:
Visualization
Line plot showing linear growth of effective context with layer depth, compared to constant direct attention span.
Effective context growth with layer depth in Transformer-XL. The maximum theoretical context (solid line) grows linearly with the number of layers, while direct attention span (dashed line) remains constant. For a 24-layer model with 512-token segments, the effective context exceeds 12,000 tokens.
Out[12]:
Visualization
Diagram showing expanding context reach at higher transformer layers, with layer 1 seeing 1 previous segment and layer 6 seeing 6 previous segments.
Information flow across layers and segments in Transformer-XL. Deeper layers have access to information from more distant segments because each layer's cached states already incorporate aggregated information from previous segments. The colored bands show the reachable context at each layer.

The visualization shows how context reach expands with network depth. Layer 1 can only see the immediately preceding segment through cached states. Layer 6, however, has access to information from 6 segments back because the cached states at layer 5 already incorporated information that propagated through the previous segment's entire network.

The Position Encoding Problem

Standard absolute position encodings break under segment-level recurrence. If we use learned or sinusoidal position embeddings based on absolute positions within each segment, we encounter a fundamental inconsistency.

Consider a token at position 5 in segment $\tau$. In the previous segment $\tau - 1$, there was also a token at position 5. Both receive the same absolute position encoding. But from the perspective of the current segment, these tokens are at very different distances: position 5 in the current segment is "here," while position 5 in the cached segment is 8 positions back (if segments have length 8).

In[13]:
Code
def demonstrate_position_conflict():
    """
    Show how absolute position encodings create confusion
    in segment-level recurrence.
    """
    segment_length = 8

    # Absolute positions assigned during each segment's processing
    segment_tau_minus_1_positions = list(
        range(segment_length)
    )  # [0, 1, 2, ..., 7]
    segment_tau_positions = list(range(segment_length))  # [0, 1, 2, ..., 7]

    # True temporal distances from perspective of segment tau
    # Cached tokens are at positions -8 to -1 relative to segment tau's start
    true_distances_cached = list(range(-segment_length, 0))  # [-8, -7, ..., -1]
    true_distances_current = list(range(segment_length))  # [0, 1, ..., 7]

    return {
        "cached_absolute_pos": segment_tau_minus_1_positions,
        "current_absolute_pos": segment_tau_positions,
        "cached_true_distance": true_distances_cached,
        "current_true_distance": true_distances_current,
    }


pos_info = demonstrate_position_conflict()
Out[14]:
Console
Position Encoding Conflict in Segment-Level Recurrence

Cached Segment (τ-1):
  Absolute positions (as encoded): [0, 1, 2, 3, 4, 5, 6, 7]
  True distance from current τ:    [-8, -7, -6, -5, -4, -3, -2, -1]

Current Segment (τ):
  Absolute positions (as encoded): [0, 1, 2, 3, 4, 5, 6, 7]
  True distance from current τ:    [0, 1, 2, 3, 4, 5, 6, 7]

Problem: Position 5 in cached segment and position 5 in current segment
both have the same absolute encoding, but their true distances differ by 8!

If absolute position encodings were used, the model would receive conflicting signals. Two tokens with identical position encodings would be at different temporal distances. The attention mechanism, which relies on position information to understand sequence structure, would be confused.

Transformer-XL solves this with relative position encodings. Instead of encoding absolute positions and adding them to token embeddings, the model directly encodes the relative distance between query and key positions in the attention computation itself.

Relative Position Encoding in Transformer-XL

We've established that segment-level recurrence breaks absolute position encodings. A token at position 5 in the cached segment and position 5 in the current segment have the same absolute encoding, yet they are 8 positions apart from the perspective of the current segment. How do we fix this? The answer lies in rethinking what position information attention actually needs.

Why Relative Distance Matters

When you read a sentence like "The cat sat on the mat because it was tired," you understand that "it" refers to "cat" not because of their absolute positions in the document, but because of their relative proximity. The pronoun comes shortly after its antecedent. This observation is the key insight: attention cares about how far apart tokens are, not where they are in absolute terms.

Consider two identical queries, one at position 10 and one at position 100, both attending to keys that are 3 positions before them. If the tokens involved have the same content, shouldn't these attention computations behave similarly? With absolute position encodings, they don't, because positions 7 and 97 have completely different encodings. With relative position encodings, they do, because "3 positions back" always means the same thing.
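
As a small illustration of this point, the check below uses a generic sinusoidal encoding (defined inline for self-containment, separate from the implementation later in this chapter). The absolute encodings of positions 7 and 97 have nothing in common, while an encoding keyed on the offset $i - j$ gives both pairs the identical position signal.

import numpy as np


def sinusoidal(pos: int, d_model: int = 8) -> np.ndarray:
    # Sinusoidal encoding of a single position (absolute or relative)
    enc = np.zeros(d_model)
    dims = np.arange(0, d_model, 2)
    angles = pos / (10000 ** (dims / d_model))
    enc[0::2] = np.sin(angles)
    enc[1::2] = np.cos(angles)
    return enc


pair_a = (10, 7)    # query at 10, key 3 positions back
pair_b = (100, 97)  # query at 100, key 3 positions back

# Absolute encodings of the two keys are completely different
print(np.allclose(sinusoidal(pair_a[1]), sinusoidal(pair_b[1])))  # False

# A relative encoding depends only on the offset i - j, so both pairs
# receive the identical position signal
print(np.allclose(sinusoidal(pair_a[0] - pair_a[1]),
                  sinusoidal(pair_b[0] - pair_b[1])))  # True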

Decomposing Standard Attention

To understand how Transformer-XL achieves relative position encoding, we need to first dissect how position information enters standard attention. In the original transformer, the attention score between a query at position $i$ and a key at position $j$ starts as a simple dot product:

$$\text{score}_{ij} = \mathbf{q}_i^\top \mathbf{k}_j$$

where:

  • $\text{score}_{ij}$: the attention score determining how much position $i$ attends to position $j$
  • $\mathbf{q}_i$: the query vector at position $i$
  • $\mathbf{k}_j$: the key vector at position $j$

But where does position come in? The original transformer adds position embeddings to token embeddings before projecting to queries and keys. So the query at position $i$ is actually $W_q(\mathbf{x}_i + \mathbf{p}_i)$ and the key at position $j$ is $W_k(\mathbf{x}_j + \mathbf{p}_j)$. Substituting these into the dot product:

$$\text{score}_{ij} = (\mathbf{x}_i + \mathbf{p}_i)^\top W_q^\top W_k (\mathbf{x}_j + \mathbf{p}_j)$$

where:

  • $\mathbf{x}_i, \mathbf{x}_j$: token embeddings at positions $i$ and $j$
  • $\mathbf{p}_i, \mathbf{p}_j$: absolute position embeddings for positions $i$ and $j$
  • $W_q, W_k$: learnable query and key projection matrices

This is where the magic of algebra reveals hidden structure. When we expand this product using the distributive property, we get four distinct terms:

$$\text{score}_{ij} = \underbrace{\mathbf{x}_i^\top W_q^\top W_k \mathbf{x}_j}_{\text{content-content}} + \underbrace{\mathbf{x}_i^\top W_q^\top W_k \mathbf{p}_j}_{\text{content-position}} + \underbrace{\mathbf{p}_i^\top W_q^\top W_k \mathbf{x}_j}_{\text{position-content}} + \underbrace{\mathbf{p}_i^\top W_q^\top W_k \mathbf{p}_j}_{\text{position-position}}$$

Each term tells us something different about why one token might attend to another:

  • Content-content: "Does this token's meaning relate to that token's meaning?" This is pure semantic matching, independent of where the tokens appear.
  • Content-position: "Given what this token is looking for, does that position matter?" For example, a verb might preferentially attend to its subject, which typically precedes it.
  • Position-content: "Given where this token is, does that token's content matter more?" Early positions might attend differently than late positions.
  • Position-position: "Do these two positions have an inherent affinity?" Adjacent positions might naturally attend to each other.
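
A quick numerical check confirms that this expansion is exact. The matrices below are random stand-ins for learned weights, used only to verify the algebra.

import numpy as np

rng = np.random.default_rng(0)
d = 8

x_i, x_j = rng.normal(size=d), rng.normal(size=d)  # token embeddings
p_i, p_j = rng.normal(size=d), rng.normal(size=d)  # absolute position embeddings
W_q, W_k = rng.normal(size=(d, d)), rng.normal(size=(d, d))

# Score computed directly from the position-augmented embeddings
direct = (x_i + p_i) @ W_q.T @ W_k @ (x_j + p_j)

# The same score as the sum of the four expanded terms
A = W_q.T @ W_k
expanded = (
    x_i @ A @ x_j    # content-content
    + x_i @ A @ p_j  # content-position
    + p_i @ A @ x_j  # position-content
    + p_i @ A @ p_j  # position-position
)

print(np.isclose(direct, expanded))  # True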

From Absolute to Relative

The problem with this decomposition is that all position information uses the absolute positions $\mathbf{p}_i$ and $\mathbf{p}_j$. Transformer-XL's insight is that we can rewrite these terms to use relative position instead. The redesign makes two key changes:

  1. Replace the key's absolute position with relative distance: Instead of $\mathbf{p}_j$ (the absolute position of the key), use $\mathbf{r}_{i-j}$ (an encoding of the relative distance from query to key).

  2. Replace the query's position with a learned global bias: The query's absolute position $\mathbf{p}_i$ becomes learned vectors $\mathbf{u}$ and $\mathbf{v}$ that don't depend on position at all.

The resulting formula is:

$$\text{score}_{ij} = \underbrace{\mathbf{x}_i^\top W_q^\top W_{k,E}\, \mathbf{x}_j}_{(a)} + \underbrace{\mathbf{x}_i^\top W_q^\top W_{k,R}\, \mathbf{r}_{i-j}}_{(b)} + \underbrace{\mathbf{u}^\top W_{k,E}\, \mathbf{x}_j}_{(c)} + \underbrace{\mathbf{v}^\top W_{k,R}\, \mathbf{r}_{i-j}}_{(d)}$$

Let's unpack each component:

  • $\mathbf{x}_i, \mathbf{x}_j$: token embeddings at positions $i$ and $j$, unchanged from before
  • $\mathbf{r}_{i-j}$: a sinusoidal encoding of the relative distance $i - j$, not the absolute position
  • $W_{k,E}$: key projection matrix for content (the "E" stands for embeddings)
  • $W_{k,R}$: key projection matrix for relative positions (the "R" stands for relative)
  • $\mathbf{u}$: a learned global bias for content attention, shared across all query positions
  • $\mathbf{v}$: a learned global bias for position attention, also shared across all positions

Why does this work? Consider what each term now captures:

  1. Term (a): Pure content-based attention. This is unchanged from standard attention. The word "cat" attends to "feline" because of semantic similarity, regardless of position.

  2. Term (b): Content-dependent distance preference. The query content determines how much the model cares about distance. A pronoun might strongly prefer nearby tokens, while a discourse marker might look further back.

  3. Term (c): Global content importance. Some tokens are just important regardless of the query's position. The beginning-of-sentence token might receive attention from everywhere.

  4. Term (d): Global distance preference. The model learns a general preference for certain distances. Typically, nearby tokens receive more attention than distant ones.

The crucial insight is that terms (b) and (d) now depend on $i - j$ rather than on $i$ and $j$ separately. This means the position signal is the same whether we're at the start of the document or the end, whether we're attending within the current segment or reaching back into cached memory.

Building the Relative Encoding

With the theory in place, let's implement relative position encoding step by step. We need two components: (1) a function to generate sinusoidal encodings for each possible relative distance, and (2) a function to compute attention scores using the four-term formula.

The sinusoidal encoding for relative positions works similarly to absolute position encodings, but instead of encoding absolute positions 0, 1, 2, ..., we encode relative distances ..., -2, -1, 0, 1, 2, .... With the convention $i - j$ (query position minus key position) used here, positive distances mean the key comes before the query and negative distances mean it comes after (in causal attention, only non-negative distances occur).

In[15]:
Code
import numpy as np


def sinusoidal_relative_encoding(
    max_distance: int, d_model: int
) -> tuple[np.ndarray, np.ndarray]:
    """
    Generate sinusoidal encodings for relative positions.

    Unlike absolute encodings, these encode the distance between positions,
    allowing the same encoding for any pair at the same distance.

    Returns:
        positions: relative distances from -max_distance to +max_distance
        encodings: encoding per distance, shape (2*max_distance + 1, d_model)
    """
    positions = np.arange(-max_distance, max_distance + 1)
    encodings = np.zeros((len(positions), d_model))

    for i, pos in enumerate(positions):
        for j in range(0, d_model, 2):
            div_term = 10000 ** (j / d_model)
            encodings[i, j] = np.sin(pos / div_term)
            if j + 1 < d_model:
                encodings[i, j + 1] = np.cos(pos / div_term)

    return positions, encodings


def relative_attention_scores(
    queries: np.ndarray,
    keys: np.ndarray,
    relative_encodings: np.ndarray,
    rel_positions: np.ndarray,
    u: np.ndarray,
    v: np.ndarray,
    W_k_E: np.ndarray,
    W_k_R: np.ndarray,
) -> np.ndarray:
    """
    Compute Transformer-XL style attention scores with relative positions.

    Args:
        queries: Query vectors, shape (L, d)
        keys: Key vectors (from extended context), shape (L+M, d)
        relative_encodings: Sinusoidal encodings indexed by distance
        rel_positions: Array mapping distance to encoding index
        u, v: Global bias vectors, shape (d,)
        W_k_E, W_k_R: Key projections for content and position

    Returns:
        Attention scores, shape (L, L+M)
    """
    L = queries.shape[0]
    K = keys.shape[0]  # L + M (current + cached)
    M = K - L  # Memory/cache length

    # Term (a): content-to-content
    term_a = queries @ W_k_E @ keys.T  # (L, K)

    # Term (b): content-to-relative-position
    # For each query position i and key position j, we need r_{i-j}
    term_b = np.zeros((L, K))
    for i in range(L):
        for j in range(K):
            # Relative distance: query position (0 to L-1) vs key position (-M to L-1)
            # Key positions 0 to M-1 correspond to cached tokens at positions -M to -1
            # Key positions M to M+L-1 correspond to current tokens at positions 0 to L-1
            key_pos = j - M  # negative for cached tokens, 0..L-1 for current tokens
            query_pos = i
            distance = query_pos - key_pos

            # Find encoding for this distance
            enc_idx = np.where(rel_positions == distance)[0]
            if len(enc_idx) > 0:
                r_ij = relative_encodings[enc_idx[0]]
                term_b[i, j] = queries[i] @ W_k_R @ r_ij

    # Term (c): global content bias
    term_c = u @ W_k_E @ keys.T  # (K,) broadcast to (L, K)
    term_c = np.tile(term_c, (L, 1))

    # Term (d): global position bias
    term_d = np.zeros((L, K))
    for i in range(L):
        for j in range(K):
            key_pos = j - M  # negative for cached tokens, 0..L-1 for current tokens
            query_pos = i
            distance = query_pos - key_pos

            enc_idx = np.where(rel_positions == distance)[0]
            if len(enc_idx) > 0:
                r_ij = relative_encodings[enc_idx[0]]
                term_d[i, j] = v @ W_k_R @ r_ij

    return term_a + term_b + term_c + term_d
In[16]:
Code
# Demonstrate relative position encoding
np.random.seed(42)

L = 4  # Current segment length
M = 4  # Cached segment length
d = 8  # Model dimension

# Generate relative position encodings
max_dist = L + M
rel_positions, rel_encodings = sinusoidal_relative_encoding(max_dist, d)

# Random queries and keys (in practice, these come from token embeddings)
queries = np.random.randn(L, d) * 0.5
keys = np.random.randn(L + M, d) * 0.5

# Learned parameters
u = np.random.randn(d) * 0.1
v = np.random.randn(d) * 0.1
W_k_E = np.eye(d) + np.random.randn(d, d) * 0.1
W_k_R = np.eye(d) + np.random.randn(d, d) * 0.1

scores = relative_attention_scores(
    queries, keys, rel_encodings, rel_positions, u, v, W_k_E, W_k_R
)
Out[17]:
Console
Attention score matrix shape: (4, 8)

Scores for query position 0 (first token of current segment):
  To cached keys (relative distances 4 down to 1): [0.57475205 1.64267582 2.18688392 1.10125465]
  To current keys (relative distances 0 to -3):    [2.05152963 0.1093185  1.47634989 0.76901824]

Scores for query position 3 (last token of current segment):
  To cached keys (relative distances 7 down to 4):  [0.77567855 1.66560434 1.37047954 1.22962436]
  To current keys (relative distances 3 down to 0): [1.45163637 0.78473083 0.56027416 0.80859854]

The scores vary based on both content similarity and relative distance. Notice that the scores to cached positions (which are further away) differ from scores to current positions (which are closer). The relative position encoding ensures consistent treatment of distance regardless of absolute segment boundaries. A token attending to something 3 positions back receives the same position signal whether that's within the current segment or reaching into the cached memory.

To better understand how these four terms contribute to the final attention score, let's decompose them for a single query-key pair across different relative distances:

Out[18]:
Visualization
Individual term contributions to attention scores. Each bar shows one of the four terms at each relative distance.
Combined attention score with stacked breakdown showing how content and position terms interact.

The visualization shows how content similarity (term a) provides the base signal, while position terms (b, c, d) modulate attention based on distance. The strong preference for position 0 (self-attention) reflects both high content similarity and the global position bias favoring nearby tokens.

Out[19]:
Visualization
Heatmap showing encoding similarity between relative positions from -8 to 8, with high similarity along the diagonal.
Relative position encoding similarity matrix showing how the sinusoidal encoding captures distance relationships. Nearby positions have similar encodings, with similarity decaying smoothly for larger distances. The symmetry around the diagonal reflects the encoding's ability to represent both positive and negative relative distances.

The similarity matrix shows that encodings for nearby relative positions are similar, gradually diverging for larger distances. This smooth decay allows the attention mechanism to naturally prefer nearby tokens while still accessing distant context when needed.

Implementing Transformer-XL

We've now covered the two core innovations of Transformer-XL: segment-level recurrence for extending context across segment boundaries, and relative position encoding for handling positions consistently across segments. Let's bring these pieces together into a complete implementation.

A Transformer-XL layer follows the same structure as a standard transformer layer: attention followed by feed-forward, with residual connections and layer normalization. The key differences are in the attention computation, where we must handle cached memory and compute relative position biases.

In[20]:
Code
import numpy as np


class TransformerXLLayer:
    """
    A single Transformer-XL layer with segment-level recurrence
    and relative position encoding.
    """

    def __init__(
        self, d_model: int, n_heads: int, d_ff: int, max_rel_dist: int = 512
    ):
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.max_rel_dist = max_rel_dist

        # Initialize parameters (simplified: using random initialization)
        np.random.seed(42)
        scale = 0.02

        # Attention projections
        self.W_q = np.random.randn(d_model, d_model) * scale
        self.W_k_E = np.random.randn(d_model, d_model) * scale  # Content
        self.W_k_R = np.random.randn(d_model, d_model) * scale  # Position
        self.W_v = np.random.randn(d_model, d_model) * scale
        self.W_o = np.random.randn(d_model, d_model) * scale

        # Global biases for relative attention
        self.u = np.random.randn(d_model) * scale
        self.v = np.random.randn(d_model) * scale

        # Feed-forward network
        self.W_ff1 = np.random.randn(d_model, d_ff) * scale
        self.b_ff1 = np.zeros(d_ff)
        self.W_ff2 = np.random.randn(d_ff, d_model) * scale
        self.b_ff2 = np.zeros(d_model)

        # Layer norm parameters (scale and shift)
        self.ln1_gamma = np.ones(d_model)
        self.ln1_beta = np.zeros(d_model)
        self.ln2_gamma = np.ones(d_model)
        self.ln2_beta = np.zeros(d_model)

        # Precompute relative position encodings
        self.rel_positions, self.rel_encodings = sinusoidal_relative_encoding(
            max_rel_dist, d_model
        )

    def layer_norm(
        self, x: np.ndarray, gamma: np.ndarray, beta: np.ndarray
    ) -> np.ndarray:
        """Apply layer normalization."""
        mean = x.mean(axis=-1, keepdims=True)
        std = x.std(axis=-1, keepdims=True) + 1e-6
        return gamma * (x - mean) / std + beta

    def relative_attention(
        self,
        hidden: np.ndarray,
        memory: np.ndarray,
    ) -> np.ndarray:
        """
        Compute relative multi-head attention with cached memory.

        Args:
            hidden: Current segment hidden states, shape (L, d_model)
            memory: Cached previous segment states, shape (M, d_model)

        Returns:
            Attention output, shape (L, d_model)
        """
        L = hidden.shape[0]
        M = memory.shape[0] if memory is not None else 0

        # Concatenate memory and hidden for keys/values
        if memory is not None and M > 0:
            extended = np.concatenate([memory, hidden], axis=0)
        else:
            extended = hidden
        K = extended.shape[0]

        # Compute queries, keys, values
        queries = hidden @ self.W_q  # (L, d)
        keys_E = extended @ self.W_k_E  # (K, d)
        values = extended @ self.W_v  # (K, d)

        # Compute attention scores with relative positions
        # Term (a): content-to-content
        scores = queries @ keys_E.T  # (L, K)

        # Terms (b), (c), (d): relative position terms
        # Simplified here: both global biases (u and v) are paired with the
        # relative encoding, whereas the exact formulation pairs u with the
        # content keys (term c)
        for i in range(L):
            for j in range(K):
                # Compute relative distance
                key_pos = j - M  # negative for memory tokens, 0..L-1 for current
                query_pos = i
                distance = query_pos - key_pos

                # Find relative encoding
                idx = np.where(
                    self.rel_positions
                    == np.clip(distance, -self.max_rel_dist, self.max_rel_dist)
                )[0]
                if len(idx) > 0:
                    r_ij = self.rel_encodings[idx[0]]
                    # Add position bias terms
                    scores[i, j] += (queries[i] + self.u) @ self.W_k_R @ r_ij
                    scores[i, j] += self.v @ self.W_k_R @ r_ij

        # Scale
        scores = scores / np.sqrt(self.d_head)

        # Causal mask (only for current segment attending to past)
        mask = np.ones_like(scores) * float("-inf")
        for i in range(L):
            # Can attend to all memory + tokens up to and including position i
            mask[i, : M + i + 1] = 0
        scores = scores + mask

        # Softmax
        exp_scores = np.exp(scores - np.max(scores, axis=-1, keepdims=True))
        attention_weights = exp_scores / (
            exp_scores.sum(axis=-1, keepdims=True) + 1e-8
        )

        # Output
        output = attention_weights @ values  # (L, d)
        output = output @ self.W_o

        return output, attention_weights

    def feed_forward(self, x: np.ndarray) -> np.ndarray:
        """Apply position-wise feed-forward network."""
        hidden = np.maximum(0, x @ self.W_ff1 + self.b_ff1)  # ReLU
        return hidden @ self.W_ff2 + self.b_ff2

    def forward(
        self, hidden: np.ndarray, memory: np.ndarray
    ) -> tuple[np.ndarray, np.ndarray]:
        """
        Forward pass for one layer.

        Args:
            hidden: Current segment hidden states
            memory: Cached memory from previous segment

        Returns:
            output: New hidden states
            new_memory: States to cache for next segment
        """
        # Self-attention with memory
        attn_out, attn_weights = self.relative_attention(hidden, memory)
        hidden = self.layer_norm(
            hidden + attn_out, self.ln1_gamma, self.ln1_beta
        )

        # Feed-forward
        ff_out = self.feed_forward(hidden)
        output = self.layer_norm(hidden + ff_out, self.ln2_gamma, self.ln2_beta)

        return output, attn_weights
In[21]:
Code
# Demonstrate the layer
layer = TransformerXLLayer(d_model=64, n_heads=4, d_ff=256, max_rel_dist=32)

# Simulate processing two segments
segment_length = 8
hidden_dim = 64

# Segment 1: no memory yet
np.random.seed(123)
segment1_input = np.random.randn(segment_length, hidden_dim) * 0.5
segment1_output, attn1 = layer.forward(segment1_input, memory=None)

# Segment 2: use segment 1's output as memory
segment2_input = np.random.randn(segment_length, hidden_dim) * 0.5
segment2_output, attn2 = layer.forward(segment2_input, memory=segment1_output)
Out[22]:
Console
Transformer-XL Layer Processing

Segment 1 (no memory):
  Input shape: (8, 64)
  Output shape: (8, 64)
  Attention shape: (8, 8)

Segment 2 (with cached memory from segment 1):
  Input shape: (8, 64)
  Memory shape: (8, 64)
  Output shape: (8, 64)
  Attention shape: (8, 16)
  Attention to memory: 5.312
  Attention to current: 2.688

The attention shape for segment 2 is (8, 16), reflecting 8 query positions attending to 16 key positions (8 cached + 8 current). The total attention sums show how much of the model's attention budget goes to memory versus the current segment. Here the cached memory receives about two-thirds of the total attention, which is expected: for the earliest positions in the current segment, the memory makes up almost all of the visible context.

Out[23]:
Visualization
Attention pattern when processing segment 2 with segment 1's cached states as memory. Current segment tokens can attend to all cached memory positions plus their own causal context.
Out[24]:
Visualization
Attention distribution per position showing how much attention each query position allocates to memory versus current segment. Early positions rely more on memory due to limited local context.

The visualization reveals how attention is distributed between memory and the current segment. Early positions in the current segment allocate substantial attention to memory because they have limited local context. Later positions can attend more to the growing local context while still accessing memory for longer-range dependencies.

Evaluation: Comparing Context Approaches

How does Transformer-XL's approach compare to other long-context methods? We can evaluate on a synthetic task that explicitly requires long-range dependencies: the copying task. The model must reproduce a sequence of tokens after a long delay filled with noise.

In[25]:
Code
def create_copying_task(
    copy_length: int, delay_length: int, vocab_size: int = 10
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Create a copying task instance.

    The input consists of:
    1. A sequence of tokens to remember (copy_length tokens)
    2. A delay period filled with blanks (delay_length tokens)
    3. A signal token indicating the model should start reproducing

    The target is the original sequence of tokens.
    """
    np.random.seed(42)

    # Tokens to copy (1 to vocab_size-2, reserve 0 for blank, vocab_size-1 for signal)
    to_copy = np.random.randint(1, vocab_size - 1, size=copy_length)

    # Build input sequence
    blank_token = 0
    signal_token = vocab_size - 1

    input_seq = np.concatenate(
        [
            to_copy,
            np.full(delay_length, blank_token),
            [signal_token],
            np.full(
                copy_length - 1, blank_token
            ),  # Positions where model outputs
        ]
    )

    # Target: just the copied tokens at the end
    target = np.concatenate(
        [
            np.full(copy_length + delay_length + 1, -1),  # -1 = ignore
            to_copy[:-1],  # Predict each token in the copy
        ]
    )

    return input_seq, target, to_copy


# Create examples with different delay lengths
delays = [50, 100, 200, 400]
copy_len = 10
examples = [create_copying_task(copy_len, d) for d in delays]
Out[26]:
Console
Copying Task Examples

Task: Remember the first 10 tokens, output them after a delay

Delay length: 50
  Tokens to copy: [7 4 5 7 3 8 5 5 7 2]
  Input length: 70
  Required context: 60 positions

Delay length: 100
  Tokens to copy: [7 4 5 7 3 8 5 5 7 2]
  Input length: 120
  Required context: 110 positions

Delay length: 200
  Tokens to copy: [7 4 5 7 3 8 5 5 7 2]
  Input length: 220
  Required context: 210 positions

Delay length: 400
  Tokens to copy: [7 4 5 7 3 8 5 5 7 2]
  Input length: 420
  Required context: 410 positions

This copying task becomes impossible for models without sufficient context. If the segment length is 100 and the delay is 200, a standard transformer with context fragmentation can never succeed because the tokens to copy fall outside any segment that needs to reproduce them.
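
As a rough back-of-the-envelope check, the snippet below (reusing delays and copy_len from above, and assuming a hypothetical 6-layer model with 100-token segments) asks which delays fit within a single segment and which fall within Transformer-XL's theoretical reach, using the same formula as compute_effective_context.

segment_len = 100
num_layers = 6
xl_reach = num_layers * segment_len + segment_len  # matches compute_effective_context

for delay in delays:
    required = copy_len + delay
    print(
        f"delay={delay}: required context={required}, "
        f"fits one segment={required <= segment_len}, "
        f"within XL theoretical reach={required <= xl_reach}"
    )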

Out[27]:
Visualization
Line plot comparing required context length against segment length threshold, showing failure region for standard transformers.
Required context length versus segment length for the copying task. Standard transformers fail when the required context exceeds the segment length (shaded region). Transformer-XL can succeed by accessing cached memory from previous segments, as long as the required context falls within its effective reach.

The figure illustrates the fundamental limitation of fixed context windows. As delay length increases, the required context eventually exceeds any fixed segment length. Standard transformers fail in the shaded regions. Transformer-XL extends the failure threshold by a factor proportional to the number of layers, but the memory cache size and information decay still impose practical limits.

Limitations of Recurrent Memory

While Transformer-XL's segment-level recurrence significantly extends effective context, it comes with important limitations that shape when and how the technique should be applied.

Information decay over segments. Hidden states are finite-dimensional vectors. As information propagates through multiple segments, it inevitably compresses and degrades. A fact stated in segment 1 may be perfectly preserved in segment 2's hidden states, partially preserved in segment 3, and largely lost by segment 10. Unlike attention over the full sequence, recurrence cannot perfectly preserve arbitrary information over arbitrary distances.

Out[28]:
Visualization
Line plot showing exponential decay of information signal strength across segments for three different initial signal levels.
Simulated information decay across segments in recurrent memory. Information about a specific token (introduced in segment 0) degrades as it propagates through hidden states. The rate of decay depends on the token's distinctiveness: distinctive tokens (high initial signal) persist longer than common tokens (low initial signal). After ~8 segments, even distinctive information becomes difficult to recover.

This decay means Transformer-XL works best for gradual, statistical dependencies rather than precise long-range retrieval. Language modeling benefits because most predictions depend on local context with only soft influence from distant text. Tasks requiring exact recall of distant tokens may still fail even with recurrence.
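
The decay in the figure can be mimicked with a toy model in which each hop between segments retains only a fraction of the signal. The retention rate and detection threshold below are arbitrary illustrative choices, not measured quantities.

import numpy as np

retention = 0.6  # illustrative: fraction of a signal surviving each segment hop
initial_signals = [1.0, 0.5, 0.2]  # distinctive, moderate, and common tokens
threshold = 0.05  # illustrative detection threshold
num_segments = 12

for s0 in initial_signals:
    strengths = s0 * retention ** np.arange(num_segments)
    recoverable = int(np.sum(strengths > threshold))
    print(f"initial signal {s0:.1f}: recoverable for ~{recoverable} segments")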

Unidirectional information flow. The recurrence mechanism flows strictly backward in time. Segment 5 can access information from segments 1-4, but segment 2 cannot access information from segment 5. This asymmetry limits bidirectional tasks. For language understanding tasks like question answering where the question appears after the context, the question cannot inform how the context is processed.

Some architectures address this with bidirectional memory or multiple passes, but these increase complexity and computation. The fundamental tradeoff between efficiency and bidirectional context remains.

Memory cache management. Storing hidden states for memory consumes GPU memory proportional to:

$$\text{Memory} = M \times d \times N_{\text{layers}} \times B$$

where:

  • $M$: the memory/cache length (number of tokens cached from the previous segment)
  • $d$: the hidden dimension of the model
  • $N_{\text{layers}}$: the number of transformer layers (each layer maintains its own cache)
  • $B$: the batch size

For large models, this becomes substantial. A 24-layer model with hidden dimension 1024 and memory length 512 requires storing over 12 million floats per sample. Batch processing multiplies this further.

In[29]:
Code
def compute_memory_requirements(
    memory_length: int,
    hidden_dim: int,
    num_layers: int,
    batch_size: int,
    bytes_per_float: int = 4,  # FP32
) -> dict:
    """
    Compute memory requirements for Transformer-XL cache.
    """
    floats_per_layer = memory_length * hidden_dim
    total_floats = floats_per_layer * num_layers * batch_size
    total_bytes = total_floats * bytes_per_float
    total_mb = total_bytes / (1024 * 1024)
    total_gb = total_mb / 1024

    return {
        "memory_length": memory_length,
        "hidden_dim": hidden_dim,
        "num_layers": num_layers,
        "batch_size": batch_size,
        "total_floats": total_floats,
        "memory_mb": total_mb,
        "memory_gb": total_gb,
    }


# Compare different configurations
configs = [
    (512, 768, 12, 8),  # BERT-base scale
    (512, 1024, 24, 8),  # BERT-large scale
    (1024, 1024, 24, 8),  # Larger memory
    (2048, 2048, 36, 4),  # GPT-2 XL scale
]

memory_stats = [compute_memory_requirements(*c) for c in configs]
Out[30]:
Console
Memory Cache Requirements for Transformer-XL

Config                             Floats     Memory (MB)  Memory (GB)
----------------------------------------------------------------------
512m × 768d × 12L × 8b         37,748,736           144.0         0.14
512m × 1024d × 24L × 8b       100,663,296           384.0         0.38
1024m × 1024d × 24L × 8b      201,326,592           768.0         0.75
2048m × 2048d × 36L × 4b      603,979,776          2304.0         2.25

The memory requirements grow quickly with model size. A GPT-2 XL scale configuration with 2048-token memory already requires over 2 GB just for the cache, not counting the model weights or activations. This overhead becomes a significant factor when deploying recurrent memory models on resource-constrained hardware.

Out[31]:
Visualization
Memory cache requirements scale linearly with memory length. With fixed model parameters (d=1024, 24 layers, batch size 8), doubling the memory length doubles the cache size.
Out[32]:
Visualization
Memory cache requirements grow rapidly with model size. As configurations scale from Small to XL, the combined effect of larger hidden dimensions, more layers, and longer memory leads to substantial cache overhead.
Memory cache requirements grow rapidly with model size. As configurations scale from Small to XL, the combined effect of larger hidden dimensions, more layers, and longer memory leads to substantial cache overhead.

Training and inference divergence. During training, the memory cache is populated from training data, maintaining realistic statistics. During inference on new text, the cache may be empty at the start, creating a "cold start" problem where initial segments lack the memory context the model learned to expect. Some implementations address this by processing a warmup prefix that isn't used for generation, but this adds latency.
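
As a sketch of the warmup idea, reusing the toy layer and dimensions from the implementation above, one can run a throwaway prefix through the model purely to populate the cache before processing the text of interest:

# Warmup prefix: its outputs are not used for prediction, only to fill the cache
warmup_prefix = np.random.randn(segment_length, hidden_dim) * 0.5
first_real_segment = np.random.randn(segment_length, hidden_dim) * 0.5

warmup_states, _ = layer.forward(warmup_prefix, memory=None)  # output discarded
real_output, _ = layer.forward(first_real_segment, memory=warmup_states)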

Gradient computation complexity. Although the cached hidden states don't receive gradients (the StopGrad operation), backpropagation still needs to flow through the attention computation over the extended sequence. This slightly increases training complexity compared to pure segment-independent processing, though it's far less expensive than full sequence backpropagation.
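
In an autograd framework such as PyTorch, the StopGrad operation corresponds to detaching the cached tensor from the computation graph. A minimal sketch of the cache update inside a training step might look like the following; the layer call and variable names are hypothetical placeholders.

import torch


def update_memory(hidden: torch.Tensor) -> torch.Tensor:
    # Detach so backpropagation never flows into the previous segment's graph;
    # the detached values are still used as keys and values for the next segment
    return hidden.detach()


# Inside a training step (sketch):
# output = layer(current_segment, memory)  # attention over [memory; current]
# loss.backward()                          # gradients stop at the detached memory
# memory = update_memory(output)           # cache for the next segment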

When to Use Recurrent Memory

Transformer-XL's approach shines in specific scenarios:

  • Language modeling on long documents where sequential processing is natural and dependencies decay with distance
  • Streaming applications where text arrives incrementally and processing must be online
  • Memory-constrained settings where full attention over long sequences is infeasible
  • Tasks with local structure where most dependencies are nearby with occasional longer-range influence

The technique is less suitable for:

  • Bidirectional understanding tasks requiring simultaneous access to all context
  • Exact long-range retrieval where specific facts must be preserved precisely across many segments
  • Tasks with unpredictable dependency patterns where important information could appear anywhere

Modern alternatives like FlashAttention and sparse attention patterns have reduced the cost of longer context windows, somewhat diminishing the need for recurrent approaches. However, when truly long sequences must be processed incrementally, segment-level recurrence remains a powerful tool.

Summary

This chapter explored Transformer-XL's recurrent memory mechanism, which extends effective context beyond the fixed segment length by caching and reusing hidden states from previous segments.

Context fragmentation, where fixed-length segment processing breaks cross-segment dependencies, fundamentally limits standard transformers. Transformer-XL addresses this by caching the hidden states from the previous segment and concatenating them with the current segment's keys and values. Queries still come only from the current segment, creating an asymmetric attention pattern that enables information flow from past to present.

The recurrence mechanism requires relative position encodings because absolute positions would be ambiguous across segments. Transformer-XL redesigns the attention score computation to depend on the relative distance between query and key positions rather than their absolute locations. This involves replacing the position-dependent terms in attention with learnable global biases and relative position encoding vectors.

The effective context length grows with network depth. Information propagates one segment further back at each layer, so an $N$-layer model has a theoretical reach of $N$ segments beyond the directly attended memory. In practice, information decay limits this reach, but the extension is still substantial.

Key implementation considerations include:

  • Memory cache storage scales with memory length, hidden dimension, number of layers, and batch size
  • The StopGrad operation prevents backpropagation through cached states, limiting training signal for long-range learning
  • Cold start at inference time may require warmup prefixes to populate the cache
  • Unidirectional information flow limits applicability to bidirectional tasks

Recurrent memory represents a principled approach to the long-context problem: accept that we cannot attend to everything at once, but ensure that information can flow across the boundaries we impose. While modern advances in attention efficiency have expanded what "at once" can mean, the core insight, that hidden states can carry forward context without explicit attention, remains valuable for streaming and memory-efficient processing of long sequences.

Key Parameters

When implementing Transformer-XL or similar recurrent memory mechanisms, the following parameters have the greatest impact on model behavior:

  • segment_length: The number of tokens processed in each forward pass. Larger segments capture more local context but increase memory usage quadratically (due to attention). Typical values range from 128 to 512 tokens.

  • memory_length: The number of tokens cached from the previous segment. Usually set equal to segment_length, but can be larger to extend context reach at the cost of increased memory and computation.

  • num_layers: Deeper networks extend effective context linearly. A 24-layer model can theoretically access 24x more context than a single layer, though information decay limits practical gains.

  • d_model: The hidden dimension affects both model capacity and memory requirements. Cache memory scales linearly with this parameter. Common values are 768 (BERT-base) to 1024 (GPT-2).

  • max_rel_dist: The maximum relative distance for position encodings. Should be at least segment_length + memory_length to cover all possible query-key distances. Setting this too small causes position information to saturate for distant tokens.
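
A small configuration sketch gathers these knobs in one place. The defaults below are illustrative values within the typical ranges mentioned above, not recommendations for any particular model.

from dataclasses import dataclass


@dataclass
class TransformerXLConfig:
    segment_length: int = 256
    memory_length: int = 256  # usually equal to segment_length
    num_layers: int = 12
    d_model: int = 768
    max_rel_dist: int = 512  # at least segment_length + memory_length


config = TransformerXLConfig()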
