LoRA Implementation: PyTorch Code & PEFT Integration

Michael Brenndoerfer · December 2, 2025 · 37 min read

Learn to implement LoRA adapters in PyTorch from scratch. Build modules, inject into transformers, merge weights, and use HuggingFace PEFT for production.


LoRA Implementation

Having explored the mathematical foundations of LoRA in the previous chapter, we now turn to the practical challenge of translating those equations into working code. The elegance of LoRA lies not just in its theoretical properties, but in how naturally it fits into existing neural network frameworks. A well-designed LoRA module should integrate seamlessly with pre-trained models, require minimal changes to training code, and support efficient inference through weight merging. This transition from theory to practice makes the concepts more concrete.

This chapter walks through the complete implementation journey, providing both the conceptual understanding and the working code you need to apply LoRA in your own projects. We start by designing a reusable LoRA module that encapsulates the low-rank decomposition, explaining each design decision along the way. From there, we build the machinery to inject LoRA into existing transformer layers, enabling you to adapt any compatible model with just a few lines of code. We then construct a training loop that updates only the adaptation parameters, demonstrating the memory efficiency that makes LoRA practical for limited hardware. The chapter proceeds to implement weight merging for deployment, showing how to eliminate runtime overhead when moving to production. Finally, we conclude with practical guidance on using HuggingFace's PEFT library, which provides production-ready implementations of these concepts and integrates with the broader transformers ecosystem.

LoRA Module Design

The core building block of any LoRA implementation is a module that wraps an existing linear layer and adds the low-rank adaptation path. This wrapper pattern is fundamental to LoRA's design philosophy: rather than modifying the original model architecture, we augment it with a parallel pathway that learns task-specific adjustments. Recall from the previous chapter that LoRA modifies a weight matrix $W$ by adding the product of two smaller matrices:

$$W' = W + \frac{\alpha}{r}BA$$

where:

  • $W'$: the modified weight matrix used in the forward pass
  • $W$: the original pre-trained weight matrix (frozen), with shape $d_{\text{out}} \times d_{\text{in}}$
  • $B$: the projection-up matrix ($\mathbb{R}^{d_{\text{out}} \times r}$), initialized to zeros
  • $A$: the projection-down matrix ($\mathbb{R}^{r \times d_{\text{in}}}$), initialized with random noise
  • $r$: the low rank of the adaptation (typically 4-64)
  • $\alpha$: the scaling factor that controls the adaptation strength

Understanding the role of each component helps clarify why this decomposition works. The original weight $W$ remains frozen throughout training, preserving all the knowledge the model acquired during pre-training. This frozen foundation provides stability and prevents catastrophic forgetting of general capabilities. Meanwhile, the matrices $A$ and $B$ are the only trainable parameters, forming a low-rank "bottleneck" through which all adaptation must pass. The matrix $A$ projects the input down to a much smaller dimension $r$, capturing the most relevant features for the adaptation task. The matrix $B$ then projects back up to the output dimension, translating these compressed representations into modifications to the layer's output.

This decomposition assumes that weight updates have a low "intrinsic rank," meaning a low-rank matrix can approximate the difference between a pre-trained and fine-tuned model. This hypothesis has strong empirical support: fine-tuning typically modifies weights in structured ways that don't require the full expressiveness of dense updates. The scaling factor $\alpha/r$ ensures that update magnitudes stay consistent across different ranks. When you increase the rank to capture more complex adaptations, the scaling automatically compensates by reducing the per-parameter contribution, maintaining stable learning dynamics regardless of your rank selection.
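To make the low-rank constraint concrete, here is a quick standalone check. The random matrices below are illustrative stand-ins for trained factors, not part of the implementation that follows: however large the full weight matrix is, the product $BA$ can never have rank greater than $r$, and it is described by far fewer numbers.

import torch

torch.manual_seed(0)
d_out, d_in, r = 768, 768, 8

# Random stand-ins for the LoRA factors (illustration only)
B = torch.randn(d_out, r)
A = torch.randn(r, d_in)
delta_W = B @ A  # the update that LoRA can express

print(delta_W.shape)                             # a full 768 x 768 matrix
print(torch.linalg.matrix_rank(delta_W).item())  # but its rank is at most 8
print(B.numel() + A.numel(), "parameters vs", d_out * d_in)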

Let's implement this as a PyTorch module:

In[2]:
Code
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class LoRALinear(nn.Module):
    """
    A linear layer with Low-Rank Adaptation (LoRA).

    Wraps an existing linear layer and adds a parallel low-rank path.
    The original weights are frozen, only A and B matrices are trained.
    """

    def __init__(
        self,
        original_layer: nn.Linear,
        rank: int = 4,
        alpha: float = 1.0,
        dropout: float = 0.0,
    ):
        super().__init__()

        self.original_layer = original_layer
        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank

        # Get dimensions from the original layer
        self.in_features = original_layer.in_features
        self.out_features = original_layer.out_features

        # Freeze original weights
        for param in self.original_layer.parameters():
            param.requires_grad = False

        # Initialize LoRA matrices
        # A: in_features -> rank (initialized with Kaiming)
        # B: rank -> out_features (initialized with zeros)
        self.lora_A = nn.Parameter(torch.empty(rank, self.in_features))
        self.lora_B = nn.Parameter(torch.zeros(self.out_features, rank))

        # Initialize A with Kaiming uniform
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))

        # Optional dropout on LoRA path
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Original path (frozen)
        original_output = self.original_layer(x)

        # LoRA path: x @ A^T @ B^T * scaling
        lora_output = self.dropout(x)
        lora_output = F.linear(lora_output, self.lora_A)  # x @ A^T
        lora_output = F.linear(lora_output, self.lora_B)  # (x @ A^T) @ B^T
        lora_output = lora_output * self.scaling

        return original_output + lora_output

Several design decisions merit careful explanation, as they directly affect the behavior and effectiveness of the adaptation. First, we initialize $B$ with zeros and $A$ with Kaiming uniform initialization. This asymmetric initialization strategy ensures that at the start of training, the LoRA contribution is exactly zero because $BA = 0$ when $B$ contains only zeros. As a result, the model behaves identically to the pre-trained version before any training occurs. This prevents any initial degradation in model quality and provides a clean starting point where all changes come from learned adaptations rather than random noise. The Kaiming initialization for $A$ follows standard practice for layers that will be trained, ensuring appropriate variance for gradient flow.

Second, the scaling factor $\alpha/r$ controls the magnitude of the LoRA updates in a principled way. This scaling emerges from thinking about how the total update magnitude should behave as rank changes. Without scaling, doubling the rank would roughly double the magnitude of updates (since you're summing over more dimensions). The $\alpha/r$ scaling compensates for this effect, making it easier to tune other hyperparameters like learning rate without worrying about rank-dependent interactions. In practice, $\alpha$ is often set equal to the rank or to a fixed value like 16 or 32, depending on how aggressively you want the adaptation to influence the model's behavior.

Let's verify our module works correctly:

In[3]:
Code
# Create a base linear layer
base_layer = nn.Linear(768, 768)

# Wrap it with LoRA
lora_layer = LoRALinear(base_layer, rank=8, alpha=8.0)

# Test forward pass
test_input = torch.randn(2, 10, 768)  # batch=2, seq_len=10, hidden=768
output = lora_layer(test_input)

# Calculate parameter counts
base_params = sum(p.numel() for p in base_layer.parameters())
lora_params = lora_layer.lora_A.numel() + lora_layer.lora_B.numel()
percentage = (lora_params / base_params) * 100
Out[4]:
Console
Input shape:  torch.Size([2, 10, 768])
Output shape: torch.Size([2, 10, 768])

Original layer parameters: 590,592
LoRA A parameters: 6,144
LoRA B parameters: 6,144
Total LoRA parameters: 12,288
Parameters added: 2.1%

The parameter counts show LoRA's efficiency and explain why this technique is popular for adapting large models. The LoRA adaptation adds a very small fraction of parameters compared to the original dense layer. To understand why, consider the arithmetic: a dense layer with input and output dimensions of 768 has $768 \times 768 = 589{,}824$ parameters (plus bias). A rank-8 LoRA adds only $768 \times 8 + 8 \times 768 = 12{,}288$ parameters, representing just over 2% of the original. This ratio improves as model dimensions grow, making LoRA especially effective for large language models where hidden dimensions commonly reach 4096 or higher.
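A quick sketch of the same arithmetic across a few representative hidden sizes (with the rank held at 8) shows how the added fraction shrinks as the layer grows. The dimensions here are illustrative rather than tied to any particular model:

rank = 8
for d in [768, 1024, 2048, 4096, 8192]:
    dense_params = d * d               # square dense layer, ignoring bias
    lora_params = d * rank + rank * d  # parameters in A and B combined
    print(f"d={d:5d}  dense={dense_params:>11,}  lora={lora_params:>7,}  "
          f"added={100 * lora_params / dense_params:.2f}%")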

Out[5]:
Visualization
Bar chart comparing parameter counts for a standard linear layer versus its LoRA components. The low-rank matrices A and B together add only about 2% of the original layer's parameters, enabling extremely efficient fine-tuning.

Verifying Initial Behavior

Since $B$ is initialized to zeros, the LoRA layer should produce identical outputs to the original layer at initialization. This property is essential for ensuring that applying LoRA doesn't accidentally degrade a model's capabilities before training begins. Let's verify this guarantee holds in practice:

In[6]:
Code
# Compare outputs before any training
original_output = base_layer(test_input)
lora_output = lora_layer(test_input)
difference = (original_output - lora_output).abs().max()
Out[7]:
Console
Maximum difference at initialization: 0.00e+00

The difference is exactly zero, confirming that our zero-initialization strategy works. The model produces exactly the same outputs as before LoRA was applied, giving you confidence that you can add LoRA to any pre-trained model without risking immediate performance degradation.

Key Parameters

Understanding the key parameters for our LoRALinear implementation helps you make informed decisions when applying LoRA to your own models:

  • rank: The rank $r$ of the low-rank decomposition determines the size of matrices $A$ and $B$ and thus controls the capacity of the adaptation. Lower ranks use fewer parameters and memory but may limit the complexity of adaptations the model can learn. Higher ranks provide more expressiveness but increase computational cost.
  • alpha: The scaling factor $\alpha$ controls the magnitude of the adaptation updates relative to the original layer outputs. Higher values make the LoRA contribution more influential, effectively amplifying the learning rate for the adaptation pathway.
  • dropout: The probability of zeroing elements in the LoRA path during training, which helps prevent overfitting when fine-tuning on small datasets. This regularization applies only to the adaptation pathway, leaving the frozen base model unaffected.

Integrating LoRA into Transformer Models

Real transformer models contain many linear layers distributed throughout their architecture: the query, key, and value projections in each attention head, the output projection that recombines attention results, and the feed-forward networks that process each position independently. A typical LoRA configuration targets the attention projections, as these capture the most task-relevant transformations by controlling what information the model attends to and how it combines that information.

Applying LoRA to a full model requires identifying and replacing layers without manually editing the architecture. We need an automated approach that can traverse the model's module hierarchy, find layers matching our target criteria, and wrap them with LoRA adapters while preserving the model's structure.

Let's create a utility function that recursively replaces specified linear layers with LoRA versions:

In[8]:
Code
def inject_lora(
    model: nn.Module,
    target_modules: list[str],
    rank: int = 4,
    alpha: float = 1.0,
    dropout: float = 0.0,
) -> nn.Module:
    """
    Inject LoRA adapters into specified modules of a model.

    Args:
        model: The base model to modify
        target_modules: List of module name patterns to target
        rank: LoRA rank
        alpha: LoRA scaling factor
        dropout: Dropout probability on LoRA path

    Returns:
        Modified model with LoRA layers injected
    """
    # Collect modules to replace first to avoid modifying graph during traversal
    modules_to_replace = []
    for name, module in model.named_modules():
        if any(target in name for target in target_modules) and isinstance(
            module, nn.Linear
        ):
            modules_to_replace.append((name, module))

    for name, module in modules_to_replace:
        # Create LoRA wrapper
        lora_layer = LoRALinear(module, rank=rank, alpha=alpha, dropout=dropout)

        # Replace the module in the parent
        # Navigate to parent module
        parts = name.split(".")
        parent = model
        for part in parts[:-1]:
            if part.isdigit():
                parent = parent[int(part)]
            else:
                parent = getattr(parent, part)

        # Replace the layer
        if parts[-1].isdigit():
            parent[int(parts[-1])] = lora_layer
        else:
            setattr(parent, parts[-1], lora_layer)

    return model

This function implements a recursive traversal strategy that examines every module in the model hierarchy. For each module, it checks whether the module name contains any of the target patterns, allowing flexible specification of which layers to adapt. The function then verifies that the module is actually a linear layer (since LoRA applies specifically to linear transformations), creates a LoRA wrapper around it, and performs the replacement by navigating to the parent module and updating the attribute. This pattern-based approach means you can target layers by their role (like "q_proj" for query projections) rather than their exact position in the model structure.

Let's demonstrate this with a simplified transformer block that contains the essential components found in production models:

In[9]:
Code
class SimpleAttention(nn.Module):
    """Simplified attention for demonstration."""

    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.o_proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        batch, seq_len, _ = x.shape

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # Reshape for multi-head attention
        q = q.view(batch, seq_len, self.num_heads, self.head_dim).transpose(
            1, 2
        )
        k = k.view(batch, seq_len, self.num_heads, self.head_dim).transpose(
            1, 2
        )
        v = v.view(batch, seq_len, self.num_heads, self.head_dim).transpose(
            1, 2
        )

        # Scaled dot-product attention
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, v)

        # Reshape and project
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch, seq_len, self.hidden_size)

        return self.o_proj(attn_output)


class SimpleFeedForward(nn.Module):
    """Simple FFN for demonstration."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.up_proj = nn.Linear(hidden_size, intermediate_size)
        self.down_proj = nn.Linear(intermediate_size, hidden_size)
        self.act = nn.GELU()

    def forward(self, x):
        return self.down_proj(self.act(self.up_proj(x)))


class SimpleTransformerBlock(nn.Module):
    """A basic transformer block."""

    def __init__(self, hidden_size: int = 256, num_heads: int = 4):
        super().__init__()
        self.attention = SimpleAttention(hidden_size, num_heads)
        self.ffn = SimpleFeedForward(hidden_size, hidden_size * 4)
        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)

    def forward(self, x):
        x = x + self.attention(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x

Now we can inject LoRA into the attention projections, demonstrating how the injection process affects parameter counts and training behavior:

In[10]:
Code
# Create a transformer block
transformer = SimpleTransformerBlock(hidden_size=256, num_heads=4)

# Count parameters before LoRA
total_before = sum(p.numel() for p in transformer.parameters())
trainable_before = sum(
    p.numel() for p in transformer.parameters() if p.requires_grad
)

# Inject LoRA into query and value projections
transformer = inject_lora(
    transformer, target_modules=["q_proj", "v_proj"], rank=8, alpha=8.0
)

# Count parameters after LoRA
total_after = sum(p.numel() for p in transformer.parameters())
trainable_after = sum(
    p.numel() for p in transformer.parameters() if p.requires_grad
)
Out[11]:
Console
Parameter counts:
  Before LoRA - Total: 789,760, Trainable: 789,760
  After LoRA  - Total: 797,952, Trainable: 666,368

Trainable parameter reduction: 15.6%
Out[12]:
Visualization
Bar chart showing the shift of parameters from trainable to frozen after LoRA injection. The total parameter count stays nearly constant, while the wrapped projection weights move into the frozen category and learning is concentrated in the low-rank adapters.
Out[13]:
Visualization
Pie chart illustrating the parameter distribution in the LoRA-adapted transformer block. The low-rank adapters account for only a small fraction of the total parameters, highlighting the memory savings available during fine-tuning.

The model retains the same representational capacity as before, since all original weights are preserved and still contribute to the forward pass. Note, however, that our minimal inject_lora only freezes the weights of the layers it wraps: the key and output projections, the feed-forward network, and the layer norms still have requires_grad=True, which is why the trainable count above drops by only about 16%. To get LoRA's full benefit, you either freeze the remaining base parameters as well (a helper for this is sketched below) or, as in the training loop that follows, pass only the LoRA parameters to the optimizer so that all learning is concentrated in the small adapter matrices.
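A minimal sketch of such a freezing helper, continuing from the transformer example above (this is a convenience added here, not part of inject_lora):

def freeze_non_lora(model: nn.Module) -> None:
    """Freeze every parameter whose name does not mark it as a LoRA matrix."""
    for name, param in model.named_parameters():
        param.requires_grad = "lora_" in name


freeze_non_lora(transformer)
trainable = sum(p.numel() for p in transformer.parameters() if p.requires_grad)
print(f"Trainable parameters after freezing the rest: {trainable:,}")

For the block above, this should leave only the 8,192 adapter parameters trainable, which is exactly the difference between the before and after totals reported earlier.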

The LoRA Training Loop

Training a LoRA-equipped model follows the standard PyTorch training pattern, with one key simplification: the optimizer only needs to handle the LoRA parameters. This means smaller optimizer state memory, faster parameter updates, and less gradient computation, since gradients with respect to the frozen base weights never need to be materialized.

Let's build a complete training example to illustrate these concepts in action:

In[14]:
Code
def get_lora_parameters(model: nn.Module):
    """Extract only the LoRA parameters from a model."""
    lora_params = []
    for name, param in model.named_parameters():
        if "lora_" in name:
            lora_params.append(param)
    return lora_params


def count_trainable_parameters(model: nn.Module) -> tuple[int, int]:
    """Count total and trainable parameters."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable

Now let's implement a training loop for a simple sequence classification task, which demonstrates how LoRA integrates into a realistic training workflow:

In[15]:
Code
class SequenceClassifier(nn.Module):
    """Simple classifier using transformer blocks."""

    def __init__(
        self,
        vocab_size: int,
        hidden_size: int = 256,
        num_heads: int = 4,
        num_layers: int = 2,
        num_classes: int = 2,
    ):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.pos_embedding = nn.Embedding(512, hidden_size)

        self.layers = nn.ModuleList(
            [
                SimpleTransformerBlock(hidden_size, num_heads)
                for _ in range(num_layers)
            ]
        )

        self.classifier = nn.Linear(hidden_size, num_classes)

    def forward(self, input_ids):
        batch_size, seq_len = input_ids.shape
        positions = torch.arange(seq_len, device=input_ids.device)

        x = self.embedding(input_ids) + self.pos_embedding(positions)

        for layer in self.layers:
            x = layer(x)

        # Pool: take mean of sequence
        pooled = x.mean(dim=1)

        return self.classifier(pooled)
In[16]:
Code
from torch.utils.data import TensorDataset, DataLoader

# Create synthetic training data
torch.manual_seed(42)
vocab_size = 1000
num_samples = 500
seq_length = 32

# Generate random "sentences" and binary labels
train_inputs = torch.randint(0, vocab_size, (num_samples, seq_length))
train_labels = torch.randint(0, 2, (num_samples,))

# Create data loader
dataset = TensorDataset(train_inputs, train_labels)
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
In[17]:
Code
# Initialize model
model = SequenceClassifier(vocab_size=vocab_size, num_classes=2)

# Calculate baseline parameter counts
total, trainable = count_trainable_parameters(model)
Out[18]:
Console
Before LoRA:
  Total parameters: 1,967,106
  Trainable parameters: 1,967,106

At this stage, all parameters in the model are trainable. This represents the baseline for full fine-tuning, which would require updating every weight in the network. For large models, this baseline represents a significant memory burden since the optimizer must store momentum and variance estimates (in the case of Adam-family optimizers) for every parameter.
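To put that burden in numbers, here is a rough back-of-the-envelope sketch. It assumes fp32 storage with the 4-bytes-each breakdown for the parameter, its gradient, and the two Adam moment estimates, and it ignores activation memory, so treat the results as lower bounds:

def adam_training_state_gb(num_params: int, bytes_per_param: int = 16) -> float:
    """Approximate fp32 memory for parameters, gradients, and Adam moments."""
    return num_params * bytes_per_param / 1024**3


# Our toy classifier versus a 7B-parameter model, same arithmetic
for n in [1_967_106, 7_000_000_000]:
    print(f"{n:>13,} params -> ~{adam_training_state_gb(n):.2f} GB of training state")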

In[19]:
Code
# Inject LoRA into attention projections across all layers
model = inject_lora(
    model, target_modules=["q_proj", "v_proj"], rank=4, alpha=4.0
)

# Verify parameter counts
total, trainable = count_trainable_parameters(model)
Out[20]:
Console

After LoRA injection:
  Total parameters: 1,975,298
  Trainable parameters: 1,712,130
  Percentage trainable: 86.68%

After injecting LoRA, the targeted query and value projections are frozen and wrapped with low-rank adapters. The trainable percentage reported above remains high because this minimal injection leaves the embeddings, the untouched attention and feed-forward layers, and the classifier with requires_grad=True. The efficiency comes from the next step: we hand the optimizer only the LoRA parameters, so optimizer state and weight updates are confined to the small adapter matrices. This is what enables LoRA to fine-tune models that would otherwise exceed available memory.

In[21]:
Code
# Setup optimizer with only LoRA parameters
lora_params = get_lora_parameters(model)
optimizer = torch.optim.AdamW(lora_params, lr=1e-3, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()

# Training loop
num_epochs = 5
train_losses = []

model.train()
for epoch in range(num_epochs):
    epoch_loss = 0.0
    num_batches = 0

    for batch_inputs, batch_labels in train_loader:
        optimizer.zero_grad()

        outputs = model(batch_inputs)
        loss = criterion(outputs, batch_labels)

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

    avg_loss = epoch_loss / num_batches
    train_losses.append(avg_loss)
Out[22]:
Console
Training progress:
  Epoch 1: Loss = 0.7069
  Epoch 2: Loss = 0.6981
  Epoch 3: Loss = 0.6940
  Epoch 4: Loss = 0.6909
  Epoch 5: Loss = 0.6857

The training loop demonstrates that LoRA integrates naturally with standard PyTorch workflows. The code structure is virtually identical to full fine-tuning, with only two modifications required. First, we inject LoRA into target modules using our utility function. Second, we pass only LoRA parameters to the optimizer, ensuring that gradient updates only apply to the adaptation matrices. Everything else, including the forward pass, loss computation, backpropagation, and optimizer stepping, proceeds exactly as in conventional training.

Let's visualize the training progress:

Out[23]:
Visualization
Line plot showing decreasing training loss across five epochs.
Training loss trajectory over five epochs of LoRA fine-tuning. The steady decrease in loss indicates that the model successfully learns task-specific patterns despite updating less than 2% of its total parameters.

Weight Merging for Efficient Inference

During training, LoRA maintains separate paths for the original weights and the low-rank adaptation. This separation is necessary because the original weights must stay frozen while only the adaptation matrices receive gradient updates. However, this dual-path architecture introduces computational overhead during inference: each adapted layer must perform the original matrix multiplication plus the additional low-rank computation.

At inference time, we can merge the LoRA weights back into the base model, eliminating any computational overhead. The merged model produces identical outputs to the unmerged version but requires only a single matrix multiplication per layer:

$$W_{\text{merged}} = W + \frac{\alpha}{r}BA$$

where:

  • $W_{\text{merged}}$: the consolidated weight matrix used for inference
  • $W$: the original frozen weight matrix
  • $\alpha$: the scaling factor
  • $r$: the rank parameter
  • $B$: the projection-up matrix
  • $A$: the projection-down matrix
  • $\frac{\alpha}{r}BA$: the effective update matrix computed from the LoRA parameters

This merging relies on the distributive property of linear transformations over addition. Specifically, applying a linear layer with weight $W$ and then adding the output of another linear operation with weight $BA$ is mathematically equivalent to applying a single linear layer with weight $W + BA$. We can verify that the merged weight produces identical outputs to the separate paths. For an input vector $\mathbf{x}$, the forward pass equivalence is derived as follows:

$$
\begin{aligned}
\mathbf{x} W_{\text{merged}}^T &= \mathbf{x} \left(W + \frac{\alpha}{r}BA\right)^T && \text{substitute merged weight} \\
&= \mathbf{x} W^T + \mathbf{x} \left(\frac{\alpha}{r}BA\right)^T && \text{linearity of transpose} \\
&= \mathbf{x} W^T + \frac{\alpha}{r} \mathbf{x} A^T B^T && \text{transpose of product: } (BA)^T = A^T B^T
\end{aligned}
$$

where:

  • $\mathbf{x}$: the input vector
  • $W_{\text{merged}}^T$: the transpose of the merged weight matrix used in the linear layer
  • $W^T$: the transpose of the original frozen weight matrix
  • $\alpha$: the scaling factor
  • $r$: the rank parameter
  • $A^T, B^T$: the transposes of the adaptation matrices (note the order reversal due to the transpose property)

The final expression shows exactly what our unmerged forward pass computes: the original layer output $\mathbf{x} W^T$ plus the scaled LoRA contribution $\frac{\alpha}{r} \mathbf{x} A^T B^T$. This mathematical equivalence confirms that we can eliminate the computational overhead during inference by using a consolidated weight matrix:

In[24]:
Code
def merge_lora_weights(model: nn.Module) -> nn.Module:
    """
    Merge LoRA weights into the base model weights.
    After merging, the model behaves identically but without LoRA overhead.
    """
    for name, module in model.named_modules():
        if isinstance(module, LoRALinear):
            # Compute merged weight: W + (alpha/r) * B @ A
            merged_weight = (
                module.original_layer.weight.data
                + module.scaling * (module.lora_B @ module.lora_A)
            )
            module.original_layer.weight.data = merged_weight

            # Mark as merged to skip LoRA path in forward
            module.merged = True

    return model


def unmerge_lora_weights(model: nn.Module) -> nn.Module:
    """
    Unmerge LoRA weights from the base model.
    Useful if you want to continue training after merging.
    """
    for name, module in model.named_modules():
        if isinstance(module, LoRALinear) and getattr(module, "merged", False):
            # Subtract the LoRA contribution
            unmerged_weight = (
                module.original_layer.weight.data
                - module.scaling * (module.lora_B @ module.lora_A)
            )
            module.original_layer.weight.data = unmerged_weight
            module.merged = False

    return model

Let's modify our LoRALinear class to handle merged weights in the forward pass, creating a version that can operate in either mode:

In[25]:
Code
class LoRALinearMergeable(nn.Module):
    """
    LoRA linear layer with support for weight merging.
    """

    def __init__(
        self,
        original_layer: nn.Linear,
        rank: int = 4,
        alpha: float = 1.0,
        dropout: float = 0.0,
    ):
        super().__init__()

        self.original_layer = original_layer
        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank
        self.merged = False

        self.in_features = original_layer.in_features
        self.out_features = original_layer.out_features

        # Freeze original weights
        for param in self.original_layer.parameters():
            param.requires_grad = False

        # Initialize LoRA matrices
        self.lora_A = nn.Parameter(torch.empty(rank, self.in_features))
        self.lora_B = nn.Parameter(torch.zeros(self.out_features, rank))
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))

        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

    def merge(self):
        """Merge LoRA weights into the base layer."""
        if not self.merged:
            self.original_layer.weight.data += self.scaling * (
                self.lora_B @ self.lora_A
            )
            self.merged = True

    def unmerge(self):
        """Unmerge LoRA weights from the base layer."""
        if self.merged:
            self.original_layer.weight.data -= self.scaling * (
                self.lora_B @ self.lora_A
            )
            self.merged = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.merged:
            # Just use the merged base layer
            return self.original_layer(x)

        # Standard LoRA forward
        original_output = self.original_layer(x)
        lora_output = self.dropout(x)
        lora_output = F.linear(lora_output, self.lora_A)
        lora_output = F.linear(lora_output, self.lora_B)
        lora_output = lora_output * self.scaling

        return original_output + lora_output

Let's verify that merging produces identical outputs, confirming our mathematical derivation:

In[26]:
Code
# Create a LoRA layer and process some input
base = nn.Linear(128, 128)
lora = LoRALinearMergeable(base, rank=4, alpha=4.0)

# Simulate some training by modifying LoRA weights
with torch.no_grad():
    lora.lora_A.uniform_(-0.1, 0.1)
    lora.lora_B.uniform_(-0.1, 0.1)

test_input = torch.randn(2, 10, 128)

# Output before merging (using LoRA path)
output_before = lora(test_input)

# Merge weights
lora.merge()

# Output after merging (direct from base layer)
output_after = lora(test_input)

# Check equivalence
diff = (output_before - output_after).abs().max()
Out[27]:
Console
Maximum difference after merging: 9.54e-07
Merged status: True

The outputs are identical up to floating-point precision, confirming that merged models have zero inference overhead compared to the original architecture. This verification provides confidence that you can deploy merged models in production without worrying about subtle behavioral differences.

Out[28]:
Visualization
Comparison of layer outputs before and after weight merging. The points lie perfectly on the diagonal line, indicating that the merged weights produce identical outputs to the separate LoRA path.
Out[29]:
Visualization
Histogram of absolute differences between pre- and post-merge layer outputs. The deviations are concentrated near the limits of numerical precision, confirming that the weight merging process is mathematically equivalent to the dual-path LoRA architecture.

Saving and Loading LoRA Weights

For practical deployment, you often want to save only the LoRA weights rather than the full model. This approach enables sharing small adapter files that can be applied to any copy of the base model, creating an efficient ecosystem where you download a large base model once and then apply multiple lightweight adapters for different tasks:

In[30]:
Code
def save_lora_weights(model: nn.Module, path: str):
    """Save only the LoRA parameters to a file."""
    lora_state_dict = {}
    for name, param in model.named_parameters():
        if "lora_" in name:
            lora_state_dict[name] = param.data.clone()
    torch.save(lora_state_dict, path)
    return lora_state_dict


def load_lora_weights(model: nn.Module, path: str):
    """Load LoRA parameters from a file."""
    lora_state_dict = torch.load(path, weights_only=True)

    model_state = model.state_dict()
    for name, param in lora_state_dict.items():
        if name in model_state:
            model_state[name] = param

    model.load_state_dict(model_state, strict=False)
    return model
In[31]:
Code
import os
import tempfile

# Demonstrate saving LoRA weights
with tempfile.TemporaryDirectory() as tmpdir:
    save_path = os.path.join(tmpdir, "lora_weights.pt")
    saved_dict = save_lora_weights(model, save_path)
    file_size = os.path.getsize(save_path)
Out[32]:
Console
LoRA weights saved: 8 parameters
File size: 35.58 KB

Saved parameter names:
  layers.0.attention.q_proj.lora_A: shape torch.Size([4, 256])
  layers.0.attention.q_proj.lora_B: shape torch.Size([256, 4])
  layers.0.attention.v_proj.lora_A: shape torch.Size([4, 256])
  layers.0.attention.v_proj.lora_B: shape torch.Size([256, 4])

The LoRA checkpoint is orders of magnitude smaller than a full model checkpoint, enabling practical storage and distribution of many task-specific adaptations. You might maintain dozens of LoRA adapters for different tasks while storing just one copy of the base model, dramatically reducing storage requirements compared to maintaining multiple fully fine-tuned model variants.
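As a sketch of what that workflow might look like with the helpers defined above: keep one copy of the base model in memory and swap adapters per task. The file paths below are hypothetical, and each adapter is assumed to have been trained on this same architecture.

# Hypothetical adapter files, one per task, each only a few dozen kilobytes
adapter_paths = {
    "sentiment": "adapters/sentiment_lora.pt",
    "topics": "adapters/topics_lora.pt",
}


def switch_adapter(model: nn.Module, task: str) -> nn.Module:
    """Load a task-specific LoRA adapter onto the shared base model."""
    return load_lora_weights(model, adapter_paths[task])


# model = switch_adapter(model, "sentiment")  # uncomment once the files exist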

HuggingFace PEFT Library

While understanding the implementation details is valuable for building intuition and debugging issues, production systems typically use the PEFT (Parameter-Efficient Fine-Tuning) library from HuggingFace. PEFT provides optimized implementations of LoRA and other adaptation methods, with seamless integration into the transformers ecosystem. Using PEFT means you benefit from tested, maintained code that handles edge cases and optimizations you might miss in a custom implementation.

Let's walk through a practical example of fine-tuning a language model with PEFT:

In[33]:
Code
# Import required packages (assuming they are already installed)
# If not installed, run: pip install peft transformers
In[34]:
Code
from transformers import DistilBertConfig, DistilBertForSequenceClassification

# In a real scenario, you would download pre-trained weights:
# model_name = "distilbert-base-uncased"
# tokenizer = AutoTokenizer.from_pretrained(model_name)
# base_model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

# For this demonstration, we initialize a random model to avoid downloads
config = DistilBertConfig(
    n_layers=2,
    n_heads=4,
    dim=256,
    hidden_dim=1024,
    vocab_size=2000,
    num_labels=2,
)
base_model = DistilBertForSequenceClassification(config)


# Create a dummy tokenizer for demonstration
class DummyTokenizer:
    def __call__(
        self,
        texts,
        truncation=True,
        padding="max_length",
        max_length=128,
        return_tensors="pt",
    ):
        batch_size = len(texts)
        return {
            "input_ids": torch.randint(0, 2000, (batch_size, max_length)),
            "attention_mask": torch.ones(
                batch_size, max_length, dtype=torch.long
            ),
        }


tokenizer = DummyTokenizer()

The PEFT library uses a configuration object to specify LoRA hyperparameters, providing a clean interface for customizing the adaptation behavior:

In[35]:
Code
from peft import LoraConfig, TaskType, get_peft_model

# Configure LoRA
lora_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["q_lin", "v_lin"],
    bias="none",
    modules_to_save=None,
)

# Wrap the model with LoRA
peft_model = get_peft_model(base_model, lora_config)

# Get parameter statistics
trainable_params, all_params = peft_model.get_nb_trainable_parameters()
Out[36]:
Console
trainable params: 82,690 || all params: 2,372,100 || trainable%: 3.4859

The output shows the dramatic reduction in trainable parameters achieved by PEFT. With just a single function call, the library has identified the target modules, wrapped them with LoRA adapters, and frozen the base weights. Let's examine the modified architecture to understand what changed:

In[37]:
Code
# Examine the LoRA structure in one attention layer
target_name = None
target_module = None

for name, module in peft_model.named_modules():
    if "lora" in name.lower() and "q_lin" in name:
        target_name = name
        target_module = module
        break
Out[38]:
Console
Sample LoRA module structure:
  Module: base_model.model.distilbert.transformer.layer.0.attention.q_lin.lora_dropout
  Type: ModuleDict

The output confirms that the targeted linear layers have been successfully wrapped with LoRA adapters. These adapters intercept the forward pass, adding the low-rank adaptation path alongside the frozen base computation.

Training with PEFT

Training a PEFT model is identical to training any HuggingFace model. The library handles freezing base weights and updating only LoRA parameters transparently, so your training code requires no special modifications:

In[39]:
Code
from torch.utils.data import Dataset


class SimpleTextDataset(Dataset):
    """Minimal dataset for demonstration."""

    def __init__(self, texts, labels, tokenizer, max_length=128):
        self.encodings = tokenizer(
            texts,
            truncation=True,
            padding="max_length",
            max_length=max_length,
            return_tensors="pt",
        )
        self.labels = torch.tensor(labels)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return {
            "input_ids": self.encodings["input_ids"][idx],
            "attention_mask": self.encodings["attention_mask"][idx],
            "labels": self.labels[idx],
        }


# Create sample data
texts = [
    "This movie was absolutely wonderful!",
    "Terrible film, waste of time.",
    "A masterpiece of modern cinema.",
    "Boring and predictable plot.",
    "Outstanding performances by the cast!",
    "I want my money back.",
    "Beautiful cinematography throughout.",
    "The worst movie I've ever seen.",
] * 10  # Repeat for more samples

labels = [1, 0, 1, 0, 1, 0, 1, 0] * 10

train_dataset = SimpleTextDataset(texts, labels, tokenizer)
In[40]:
Code
# Training loop
peft_model.train()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
peft_model = peft_model.to(device)

# Create optimizer with only trainable (LoRA) parameters
peft_optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, peft_model.parameters()),
    lr=2e-4,
    weight_decay=0.01,
)

train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)

peft_losses = []
for epoch in range(3):
    epoch_loss = 0
    for batch in train_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}

        outputs = peft_model(**batch)
        loss = outputs.loss

        peft_optimizer.zero_grad()
        loss.backward()
        peft_optimizer.step()

        epoch_loss += loss.item()

    avg_loss = epoch_loss / len(train_dataloader)
    peft_losses.append(avg_loss)
Out[41]:
Console
PEFT Training Progress:
  Epoch 1: Loss = 0.6961
  Epoch 2: Loss = 0.6723
  Epoch 3: Loss = 0.6653

The steady decrease in loss shows that the model is learning from the dataset even though only the small set of LoRA parameters is being optimized. This validates the core hypothesis behind LoRA: task-specific adaptations can be captured in a low-rank subspace.

Out[42]:
Visualization
Training loss reduction for a DistilBERT model fine-tuned using the PEFT library. The continuous decline in loss across three epochs demonstrates that low-rank adaptation effectively captures the necessary task information with minimal parameter updates.

Saving and Loading PEFT Models

PEFT provides convenient methods for saving only the adapter weights, making it easy to distribute lightweight task-specific adaptations:

In[43]:
Code
with tempfile.TemporaryDirectory() as tmpdir:
    # Save only the LoRA adapter (not the full model)
    peft_model.save_pretrained(tmpdir)

    # List saved files
    saved_files = os.listdir(tmpdir)
    file_sizes = {
        f: os.path.getsize(os.path.join(tmpdir, f)) for f in saved_files
    }
Out[44]:
Console
Saved adapter files:
  adapter_model.safetensors: 324.52 KB
  README.md: 5.01 KB
  adapter_config.json: 0.96 KB

The adapter checkpoint is just a few hundred kilobytes, compared to hundreds of megabytes for the full DistilBERT model. This enables practical workflows where you distribute a single base model with many task-specific adapters, allowing you to quickly switch between different capabilities without downloading or storing multiple complete models.
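The counterpart to save_pretrained is attaching a saved adapter to a fresh copy of the base model with PeftModel.from_pretrained. A brief sketch, assuming the adapter was saved to a directory called adapter_dir (in a real workflow the base weights would come from from_pretrained rather than a random initialization):

from peft import PeftModel

adapter_dir = "path/to/saved/adapter"  # hypothetical: wherever save_pretrained wrote to

# Recreate the base model, then attach the saved LoRA adapter
fresh_base = DistilBertForSequenceClassification(config)
restored_model = PeftModel.from_pretrained(fresh_base, adapter_dir)

# restored_model now carries the trained adapter weights and can be
# evaluated, trained further, or merged just like peft_model above.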

Merging and Unloading

PEFT also supports merging LoRA weights for inference, providing a one-line solution for converting your adapted model into a standard format:

In[45]:
Code
# Merge LoRA weights into the base model
merged_model = peft_model.merge_and_unload()
Out[46]:
Console
Merged model type: DistilBertForSequenceClassification
Parameters: 2,289,410

After merging, you have a standard transformer model with no LoRA overhead, suitable for optimized inference pipelines. The merged model can be deployed using any existing serving infrastructure without special handling for LoRA adapters.
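Because the merged model is an ordinary transformers model, it can be saved and reloaded with the standard checkpoint APIs and served without any PEFT dependency. A brief sketch, with a hypothetical output directory:

# Save the merged model as a regular transformers checkpoint
merged_model.save_pretrained("deploy/merged-distilbert")  # hypothetical path

# Reloading later needs only transformers, not peft:
# reloaded = DistilBertForSequenceClassification.from_pretrained("deploy/merged-distilbert")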

Key Parameters

The key parameters for LoraConfig in the PEFT library provide fine-grained control over the adaptation behavior:

  • r: The rank of the low-rank decomposition. Lower ranks (4-16) use fewer parameters but may capture less complex adaptations. Higher ranks provide more expressiveness at the cost of increased memory and computation.
  • lora_alpha: The scaling factor for LoRA updates. It functions similarly to a learning rate multiplier for the adapters, controlling how strongly the adaptation influences model outputs relative to the frozen base weights.
  • target_modules: The specific modules (usually linear layers) to apply LoRA to. Common targets in transformers are query and value projections (q_lin, v_lin), which control attention patterns. Targeting more modules increases adaptation capacity but also parameter count.
  • lora_dropout: Dropout probability applied to the LoRA path to prevent overfitting during training, particularly important when fine-tuning on small datasets where the model might memorize rather than generalize.

Limitations and Impact

LoRA implementation brings certain practical constraints that you should understand before applying the technique to your own models and tasks.

The rank $r$ fundamentally limits the expressiveness of the adaptation. When the required weight updates truly need full-rank modifications, such as when adapting to a dramatically different domain or learning entirely new capabilities, LoRA may underperform full fine-tuning. The low-rank constraint assumes that the difference between the pre-trained model and the adapted model lives in a low-dimensional subspace, which holds well for many fine-tuning tasks but may break down for more substantial modifications. We'll explore strategies for selecting appropriate rank values in the next chapter on LoRA hyperparameters.

The choice of target modules significantly affects both performance and efficiency. Targeting all linear layers captures the most information and provides maximum flexibility for the adaptation, but it also increases parameter count and training memory proportionally. In practice, attention projections (query and value) provide the best performance-to-parameter ratio for most tasks, as these layers directly control what information the model attends to. However, optimal targeting varies by domain: feed-forward networks may be important for tasks requiring different knowledge representations, while leaving them untouched often works well for stylistic adaptations that primarily change how existing knowledge is expressed.
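For example, a broader-targeting configuration might look like the sketch below. The q_lin and v_lin names were used earlier in this chapter; the additional names for the key, attention-output, and feed-forward projections are assumptions based on DistilBERT's naming scheme, and targeting them increases both adaptation capacity and trainable parameter count.

from peft import LoraConfig, TaskType

# Broader targeting: key, attention output, and feed-forward projections as well
wide_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["q_lin", "k_lin", "v_lin", "out_lin", "lin1", "lin2"],
)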

Memory savings during training come primarily from reduced gradient computation and optimizer states. For an 8-billion parameter model with rank-16 LoRA applied to attention layers, trainable parameters drop to roughly 20 million, about 0.25% of the original count. This dramatic reduction enables fine-tuning on a single consumer GPU that couldn't otherwise hold the optimizer states for full fine-tuning, where Adam requires approximately 16 bytes per parameter (4 bytes each for the parameter, gradient, momentum, and variance estimates).

Inference behavior depends on whether weights are merged. Unmerged LoRA adds latency from the additional matrix multiplications in the low-rank path. For batch sizes of one, which are common in interactive applications like chatbots, this overhead is typically 5-15% depending on the specific model architecture and hardware. Merging eliminates this entirely, but you lose the ability to dynamically switch between adapters or continue training. Many applications use unmerged models during development for flexibility and merge before production deployment.
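If you want to quantify the overhead for your own setup, a rough micro-benchmark along these lines (using the LoRALinearMergeable layer from earlier) gives a directional answer; absolute numbers vary widely with hardware, layer size, and batch size:

import time


def avg_forward_ms(layer: nn.Module, x: torch.Tensor, iters: int = 200) -> float:
    """Average forward-pass latency in milliseconds (no gradients)."""
    with torch.no_grad():
        layer(x)  # warm-up
        start = time.perf_counter()
        for _ in range(iters):
            layer(x)
        elapsed = time.perf_counter() - start
    return elapsed / iters * 1e3


bench_layer = LoRALinearMergeable(nn.Linear(1024, 1024), rank=16, alpha=16.0)
bench_input = torch.randn(1, 1, 1024)  # batch size 1, as in interactive serving

unmerged_ms = avg_forward_ms(bench_layer, bench_input)
bench_layer.merge()
merged_ms = avg_forward_ms(bench_layer, bench_input)
print(f"Unmerged: {unmerged_ms:.3f} ms   Merged: {merged_ms:.3f} ms")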

LoRA's impact on the field has been substantial and far-reaching. Before LoRA, fine-tuning large language models was effectively limited to organizations with significant compute resources, creating a divide between those who could customize models and those who couldn't. LoRA democratized adaptation, enabling anyone with a single GPU to customize billion-parameter models for their specific needs. This accessibility accelerated the development of domain-specific applications across medicine, law, education, and countless other fields. The technique also spawned an ecosystem of shared adapters for various tasks, where practitioners contribute and download small adapter files rather than full model weights. Subsequent techniques like QLoRA, which we'll cover in an upcoming chapter, push these efficiency gains even further by combining LoRA with quantization to enable fine-tuning on even more constrained hardware.

Summary

This chapter transformed the mathematical concepts from the previous chapter into working implementations. The key implementation components are:

  • LoRA Module Design: A LoRALinear class wraps existing linear layers, adding parallel low-rank matrices $A$ and $B$ while keeping original weights frozen. Zero-initialization of $B$ ensures the model starts identical to the pre-trained version.

  • Model Integration: A recursive injection function identifies target modules (typically attention projections) and replaces them with LoRA-wrapped versions, automatically freezing base weights.

  • Training Loop: Standard PyTorch training, but the optimizer receives only LoRA parameters. Memory savings come from reduced gradient computation and optimizer state storage.

  • Weight Merging: At inference time, $W_{\text{merged}} = W + \frac{\alpha}{r}BA$ produces a standard model with zero overhead. This can be reversed for continued training.

  • PEFT Library: HuggingFace's PEFT provides production-ready LoRA with LoraConfig for configuration, get_peft_model for wrapping, and save_pretrained/merge_and_unload for deployment workflows.

The next chapter examines how to select LoRA hyperparameters, including rank, alpha, target modules, and their interactions, to optimize for your specific task and computational constraints.
