Unsloth: How Monkey Patching Makes LLMs 2x Faster

project-watch
inference
gpu
optimization
unsloth
A guided code-reading journey through Daniel Han’s Unsloth. You will trace every function call in the Qwen3 patch — from the 3-line monkey patch idea to Triton GPU kernels — and understand WHY each decision was made. Code reading as a skill.
Published

March 11, 2026

First Break AI — Project Watch

This is a Project Watch deep dive. We study real, shipping AI projects by reading their source code. Connects to: Step 2 (inference basics), Step 3 (inference engines), Step 4 (training).


Tip: What you will be able to do after this post
  • Read and understand ~500 lines of production Python code (the Unsloth Qwen3 patch)
  • Trace a function call from from_pretrained() all the way down to a Triton GPU kernel
  • Explain what monkey patching is and why Daniel Han uses it
  • Identify the exact lines where HuggingFace execution gets rerouted to Unsloth’s fast path
  • Recognize the same operations you saw in Step 2 (RMSNorm, RoPE, KV cache, attention) — now optimized for GPU
  • Read a Triton kernel and understand what “fused” means concretely
Note: Roadmap connections
| Roadmap step | What you already know | What this post adds |
|---|---|---|
| Step 2: Run a model locally | RMSNorm, RoPE, attention, KV cache in C | How Unsloth optimizes these same ops with fused GPU kernels |
| Step 3: Inference deep dive | Inference engines, quantization | Why frameworks like Unsloth exist alongside vLLM/llama.cpp |
| Step 4: Training | PyTorch, fine-tuning, LoRA | How Unsloth patches training paths too (cross-entropy, LoRA) |

The problem before the solution

Before reading any Unsloth code, you need to understand the problem Daniel Han saw.

HuggingFace Transformers is the standard library for loading and running LLMs. It is designed for correctness and readability — not for peak GPU performance. Every forward pass goes through many small Python calls, each launching a separate GPU kernel.

Here is what a single RMSNorm looks like on the GPU when HuggingFace runs it naively:

flowchart LR
    subgraph hf ["HuggingFace default: 6 kernel launches"]
        K1["square\neach element"] --> K2["sum\nthe squares"]
        K2 --> K3["compute\nmean"]
        K3 --> K4["take\nsqrt"]
        K4 --> K5["multiply by\nreciprocal"]
        K5 --> K6["multiply by\nweight"]
    end

Each box is a kernel launch — the GPU reads data from memory, does a tiny operation, writes it back. Then the next kernel reads it again. Six round trips to memory for one normalization.

Now scale this up: a Qwen3 model has 28 layers, each with two RMSNorm calls (one before attention, one before the FFN). That is 56 RMSNorm calls per token. At 6 kernel launches each, that is 336 kernel launches just for normalization — before you even count attention, RoPE, or the FFN.
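As a quick sanity check, the launch-count arithmetic can be reproduced in a few lines (the 28-layer and 6-launch figures are the illustrative numbers used above):

```python
layers = 28                  # Qwen3 decoder layers (per the text)
rmsnorm_per_layer = 2        # one before attention, one before the FFN
launches_per_rmsnorm = 6     # naive HuggingFace path (per the text)

rmsnorm_calls = layers * rmsnorm_per_layer             # 56 per token
naive_launches = rmsnorm_calls * launches_per_rmsnorm  # 336 per token
fused_launches = rmsnorm_calls * 1                     # 56 per token, one fused launch each

print(rmsnorm_calls, naive_launches, fused_launches)  # 56 336 56
```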

Daniel’s insight: replace these paths at runtime with fused implementations that do all six steps in one kernel launch. The mechanism for doing this replacement is monkey patching.

flowchart LR
    subgraph fused ["Unsloth fused: 1 kernel launch"]
        F1["RMSNorm\nall 6 steps in one kernel\ndata stays on-chip"]
    end

That is the entire thesis: same math, fewer memory round trips, transparent to the user.

Let us now trace exactly how Daniel implements this, starting from the simplest possible version.


Lesson 1: What is monkey patching?

Monkey patching is replacing a method on a class at runtime so that all future calls to that method execute your replacement instead of the original.

In Python, every class method is just an attribute on the class object. You can overwrite it:

class Dog:
    def speak(self):
        return "Woof"

dog = Dog()
print(dog.speak())  # "Woof"

Dog.speak = lambda self: "BARK!"

print(dog.speak())  # "BARK!"

The key insight: you changed the class, not the instance. Every Dog instance — even ones created before the patch — now uses the new method.

For LLMs, the target is .forward() — the method that runs the actual computation for each layer, each attention block, and the entire model. Overwrite .forward() on the right classes, and you control how the entire model executes.

Why this matters for GPU optimization

If you write a faster version of attention, you do not need to fork HuggingFace. You do not need users to change their code. You just:

  1. Import the HuggingFace class
  2. Overwrite its .forward() with your fast version
  3. Let the user call generate() as normal

The user’s code is identical. The GPU execution is completely different.


Lesson 2: The 3-line version

The simplest real-world monkey patch on an LLM comes from TinyZero:

def apply_monkey_patch_to_qwen2():
    from transformers.models.qwen2.modeling_qwen2 import Qwen2FlashAttention2
    from verl.models.transformers.qwen2 import qwen2_flash_attn_forward
    Qwen2FlashAttention2.forward = qwen2_flash_attn_forward

Read it line by line:

Line 1: Import the normal HuggingFace attention class. This is the class that HuggingFace would normally use when you call model.generate().

Line 2: Import a custom replacement function from a different codebase (verl). This function has the same signature as the original .forward() — it takes the same arguments and returns the same types — but internally it does something different (faster).

Line 3: Overwrite the class method. After this line, every Qwen2FlashAttention2 instance in memory — including ones inside a model that was already loaded — will run the custom function when .forward() is called.

What to notice

  • Same class. Same API. Different runtime behavior.
  • The user never sees this. They call model.generate() and get the same outputs.
  • This is only a micro patch — one method on one class. It does not touch the decoder layer, the model wrapper, the generation logic, or anything else.

Daniel Han scales this exact idea to the entire model execution stack. That is what we will trace next.


Lesson 3: Open the real file

Go to the Unsloth Qwen3 source file on GitHub. Open it in a new tab. We will trace through it together.

The imports (lines 1-118)

The file starts by importing from two places:

From HuggingFace Transformers:

from transformers.models.qwen3.modeling_qwen3 import (
    Qwen3Attention,
    Qwen3DecoderLayer,
    Qwen3Model,
    Qwen3ForCausalLM,
)

These are the classes that HuggingFace would normally use. Daniel imports them so he can overwrite their methods.

From Unsloth’s own codebase:

from .llama import (
    LlamaRotaryEmbedding,
    LlamaLinearScalingRotaryEmbedding,
    _LlamaModel_fast_forward_inference,
)

These are the replacement implementations. Notice something: Daniel reuses the Llama fast forward functions for Qwen3. This is because Qwen3 and Llama share the same basic transformer architecture. Instead of writing separate fast paths for every model family, Daniel writes one optimized path and patches many model classes into it.

The attention variants:

try:
    from transformers.models.qwen3.modeling_qwen3 import (
        Qwen3SdpaAttention,
        Qwen3FlashAttention2,
    )
except:
    Qwen3SdpaAttention = Qwen3Attention
    Qwen3FlashAttention2 = Qwen3Attention

HuggingFace has multiple attention implementations for the same model — vanilla, SDPA, and Flash Attention 2. The try/except handles older Transformers versions that may not have all three. Regardless, Daniel will patch all of them to the same fast function.
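The defensive-import pattern is worth internalizing. Here it is in miniature with a stdlib name instead of transformers classes: try to import the optional name, and if it does not exist, alias it to a fallback so later code can use one name unconditionally.

```python
import math

# math.tau only exists in Python >= 3.6; the fallback alias plays the
# same role as Qwen3SdpaAttention = Qwen3Attention in Unsloth's file.
try:
    from math import tau
except ImportError:
    tau = 2 * math.pi

assert abs(tau - 2 * math.pi) < 1e-12
```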

Tip: What Daniel is solving here

HuggingFace has 3 attention classes for Qwen3. Without monkey patching, you would need to fork all 3. With monkey patching, you import all 3 and point them at one optimized function. One implementation, three entry points.


Lesson 4: The pre_patch() method — where the switch happens

Scroll down to class FastQwen3Model. Inside it, find the pre_patch() static method. This is the single most important method in the file — it is where Daniel flips the switch.

class FastQwen3Model(FastLlamaModel):
    @staticmethod
    def pre_patch():
        init_name, function = patch_linear_scaling(
            model_name="Qwen3",
            rope_module=LlamaRotaryEmbedding,
            scaled_rope_module=LlamaLinearScalingRotaryEmbedding,
            attention_module=Qwen3Attention,
        )
        if init_name is not None:
            exec(function, globals())
            Qwen3Attention.__init__ = eval(init_name)

        Qwen3Attention.forward       = Qwen3Attention_fast_forward
        Qwen3SdpaAttention.forward   = Qwen3Attention_fast_forward
        Qwen3FlashAttention2.forward = Qwen3Attention_fast_forward
        Qwen3DecoderLayer.forward    = LlamaDecoderLayer_fast_forward
        Qwen3Model.forward           = LlamaModel_fast_forward
        Qwen3ForCausalLM.forward     = CausalLM_fast_forward(
            _LlamaModel_fast_forward_inference(
                Qwen3Attention_fast_forward_inference
            )
        )
        PeftModelForCausalLM.forward = PeftModel_fast_forward
        fix_prepare_inputs_for_generation(Qwen3ForCausalLM)

        import transformers.models.qwen3.modeling_qwen3
        transformers.models.qwen3.modeling_qwen3.Qwen3RotaryEmbedding = (
            LlamaRotaryEmbedding
        )

Read each .forward = assignment as a decision Daniel made:

Level 1 — Hot-op patch (attention)

Qwen3Attention.forward       = Qwen3Attention_fast_forward
Qwen3SdpaAttention.forward   = Qwen3Attention_fast_forward
Qwen3FlashAttention2.forward = Qwen3Attention_fast_forward

All three attention variants now route to the same fast function. The user may have configured any of these; it does not matter. They all go through Daniel’s optimized path.

Level 2 — Block patch (decoder layer)

Qwen3DecoderLayer.forward = LlamaDecoderLayer_fast_forward

Each of the 28 transformer blocks in the model now runs through the fast decoder-layer path. This controls how attention and FFN are composed within each block.

Level 3 — Pipeline patch (whole model + generation)

Qwen3Model.forward       = LlamaModel_fast_forward
Qwen3ForCausalLM.forward = CausalLM_fast_forward(...)
PeftModelForCausalLM.forward = PeftModel_fast_forward

The entire model forward, the causal LM wrapper (which handles logits and loss), the PEFT/LoRA wrapper, and generation input preparation — all rerouted.

Also: the RoPE replacement

transformers.models.qwen3.modeling_qwen3.Qwen3RotaryEmbedding = (
    LlamaRotaryEmbedding
)

Daniel does not just patch methods — he replaces the entire class that HuggingFace uses for positional embeddings. Any time HF code creates a new Qwen3RotaryEmbedding, it gets Unsloth’s version instead.
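You can reproduce this class-swap mechanism with a toy module. The names below are invented, but the mechanism is exactly what pre_patch() does to transformers.models.qwen3.modeling_qwen3: any code that instantiates through the module attribute silently gets the replacement class.

```python
import types

# Hypothetical stand-in for a transformers submodule
modeling = types.ModuleType("modeling")

class SlowRotary:
    def kind(self):
        return "slow"

class FastRotary:
    def kind(self):
        return "fast"

modeling.Rotary = SlowRotary

def build_layer():
    # mimics library code that instantiates via the module attribute
    return modeling.Rotary()

assert build_layer().kind() == "slow"
modeling.Rotary = FastRotary   # the class-level swap, like pre_patch()
assert build_layer().kind() == "fast"
```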

flowchart TD
    subgraph before ["Before pre_patch()"]
        HF_ATT["Qwen3Attention.forward()\nHuggingFace default"]
        HF_DEC["Qwen3DecoderLayer.forward()\nHuggingFace default"]
        HF_MOD["Qwen3Model.forward()\nHuggingFace default"]
        HF_CLM["Qwen3ForCausalLM.forward()\nHuggingFace default"]
    end
    subgraph after ["After pre_patch()"]
        US_ATT["Qwen3Attention_fast_forward\nUnsloth optimized"]
        US_DEC["LlamaDecoderLayer_fast_forward\nUnsloth optimized"]
        US_MOD["LlamaModel_fast_forward\nUnsloth optimized"]
        US_CLM["CausalLM_fast_forward\nUnsloth optimized"]
    end
    HF_ATT -.->|"replaced by"| US_ATT
    HF_DEC -.->|"replaced by"| US_DEC
    HF_MOD -.->|"replaced by"| US_MOD
    HF_CLM -.->|"replaced by"| US_CLM

Tip: What Daniel is solving here

HuggingFace has a deep call stack: CausalLM.forward() calls Model.forward(), which calls DecoderLayer.forward() 28 times, which calls Attention.forward() + FFN. By patching every level, Daniel ensures there is no escape path — no matter which layer executes, it runs Unsloth’s code.


Lesson 5: How from_pretrained() triggers the patch

Now look at FastQwen3Model.from_pretrained():

class FastQwen3Model(FastLlamaModel):
    @staticmethod
    def from_pretrained(
        model_name="Qwen/Qwen3-7B",
        max_seq_length=4096,
        load_in_4bit=True,
        ...
    ):
        return FastLlamaModel.from_pretrained(
            model_name=model_name,
            model_patcher=FastQwen3Model,
            ...
        )

The user writes:

model, tokenizer = FastQwen3Model.from_pretrained("Qwen/Qwen3-0.6B")

Internally, model_patcher=FastQwen3Model causes pre_patch() to run before the model is returned. The model loads normally from HuggingFace Hub — same weights, same architecture — but by the time the user receives the model object, every .forward() has been rerouted.

The user then calls model.generate(...) exactly as they would with stock HuggingFace. The outputs are identical. The execution path is completely different.
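A minimal sketch of that ordering — patch before the model is handed back — with the actual loading stubbed out (FastToyModel and the calls log are invented for illustration, not Unsloth's real classes):

```python
calls = []  # records execution order

class FastToyModel:
    """Invented stand-in for FastQwen3Model."""

    @staticmethod
    def pre_patch():
        # where the real code reroutes every .forward()
        calls.append("pre_patch")

    @staticmethod
    def from_pretrained(model_name):
        FastToyModel.pre_patch()            # flip the switch...
        calls.append(f"load:{model_name}")  # ...load weights (stubbed)...
        return "model"                      # ...then return the patched model

FastToyModel.from_pretrained("Qwen/Qwen3-0.6B")
assert calls == ["pre_patch", "load:Qwen/Qwen3-0.6B"]
```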


Lesson 6: Trace the training path — Qwen3Attention_fast_forward

Now we go inside the replacement function. This is the code that runs instead of HuggingFace’s Qwen3Attention.forward() during training and prefill.

Open the file and find def Qwen3Attention_fast_forward(...). We will trace it section by section.

Step 1: QKV projection

Q, K, V = self.apply_qkv(self, hidden_states)
Q = Q.view(bsz, q_len, n_heads, head_dim)
K = K.view(bsz, q_len, n_kv_heads, head_dim)
V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)

Instead of the stock HuggingFace projection (which calls self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states) separately), Unsloth uses apply_qkv which may fuse the three projections and control memory layout.

Notice the shapes: Q gets n_heads heads, K and V get n_kv_heads heads. This is Grouped Query Attention (GQA) — Qwen3 uses fewer KV heads than query heads to save memory. You saw this in Step 2.
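A small helper makes the GQA shape bookkeeping concrete. The function and the head counts below are illustrative, not Qwen3's actual config: 16 query heads sharing 4 KV heads means each KV head serves a group of 4 query heads.

```python
def gqa_shapes(bsz, q_len, n_heads, n_kv_heads, head_dim):
    # query heads must divide evenly among the KV heads
    assert n_heads % n_kv_heads == 0
    n_groups = n_heads // n_kv_heads  # query heads per KV head
    q_shape = (bsz, q_len, n_heads, head_dim)
    kv_shape = (bsz, q_len, n_kv_heads, head_dim)
    return q_shape, kv_shape, n_groups

q, kv, groups = gqa_shapes(bsz=1, q_len=8, n_heads=16, n_kv_heads=4, head_dim=128)
print(q, kv, groups)  # (1, 8, 16, 128) (1, 8, 4, 128) 4
```

The KV cache only has to store n_kv_heads heads per layer, which is where the memory saving comes from.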

Step 2: Q/K normalization (QKNorm)

Q = fast_rms_layernorm(self.q_norm, Q)
K = fast_rms_layernorm(self.k_norm, K)

Qwen3 introduced QKNorm — applying RMSNorm to Q and K before computing attention scores. This is the same rmsnorm() you wrote in C during Step 2:

// From Step 2's run.c
void rmsnorm(float* o, float* x, float* weight, int size) {
    float ss = 0.0f;
    for (int j = 0; j < size; j++) ss += x[j] * x[j];
    ss /= size;
    ss += 1e-5f;
    ss = 1.0f / sqrtf(ss);
    for (int j = 0; j < size; j++) o[j] = weight[j] * (ss * x[j]);
}

Same math. But fast_rms_layernorm() calls a Triton kernel that does all of this in one GPU kernel launch instead of six separate operations. We will trace that kernel in Lesson 8.
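The C loop above translates almost line-for-line into plain Python — a reference implementation for checking the math, not the Triton kernel itself:

```python
import math

def rmsnorm(x, weight, eps=1e-5):
    # mean of squares, add eps, reciprocal sqrt, then scale by weight
    ss = sum(v * v for v in x) / len(x) + eps
    inv_rms = 1.0 / math.sqrt(ss)
    return [w * (inv_rms * v) for w, v in zip(weight, x)]

out = rmsnorm([3.0, 4.0], [1.0, 1.0])
# mean of squares is 12.5, so each value is scaled by roughly 0.2828
```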

Tip: What Daniel is solving here

Stock HuggingFace calls self.q_norm(Q) which goes through PyTorch’s generic LayerNorm path — multiple small kernel launches. Daniel calls fast_rms_layernorm() which goes to a fused Triton kernel. Same normalization, fewer memory round trips.

Step 3: Fast RoPE

Q, K = fast_rope_embedding(Q, K, cos, sin, rope_position_ids)

In Step 2 you traced the RoPE rotation loop — the cos/sin multiplication that gives each token a position-dependent rotation. Here, fast_rope_embedding() fuses the rotation for both Q and K in a single call, avoiding separate kernel launches for each.

Before applying RoPE, Unsloth dynamically extends the RoPE embedding if the sequence is longer than expected:

if position_embeddings and kv_seq_len <= position_embeddings[0].shape[0]:
    cos, sin = position_embeddings
else:
    rotary_emb = self.rotary_emb
    rotary_emb.extend_rope_embedding(V, seq_len=kv_seq_len)
    cos, sin = rotary_emb.get_cached(kv_seq_len, Q.device.index)

This handles sequences that are longer than what the model was originally configured for — a practical concern when users push models beyond their default context length.

Step 4: KV cache concatenation

if past_key_value is not None:
    K = torch.cat([past_key_value[0], K], dim=2)
    V = torch.cat([past_key_value[1], V], dim=2)
past_key_value = (K, V) if use_cache else None

This is the same KV cache you saw in Step 2 — previously computed keys and values are concatenated with new ones. During the training path, this is straightforward concatenation. The inference path (Lesson 7) does something much more sophisticated.

Step 5: Attention backend dispatch

backend = (
    SDPA if attention_mask is not None
    else select_attention_backend(use_varlen)
)
attention_config = AttentionConfig(
    backend=backend,
    n_kv_heads=n_kv_heads,
    n_groups=n_groups,
    ...
)
A = run_attention(config=attention_config, context=context, Q=Q, K=K, V=V)

Instead of blindly using one attention implementation, Unsloth chooses the best backend for the current situation:

  • SDPA (Scaled Dot Product Attention) — PyTorch’s built-in fused attention, used when there is an attention mask
  • Flash Attention — for dense, causal attention without masks
  • Variable-length attention — for packed/batched sequences with different lengths

This dispatch is invisible to the user but critical for performance. Different batch configurations and sequence lengths perform best with different backends.
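The dispatch logic can be sketched as a simple conditional. The function and backend names here are stand-ins for illustration, not Unsloth's actual API:

```python
def choose_backend(has_attention_mask, use_varlen):
    if has_attention_mask:
        return "sdpa"    # mask support -> PyTorch's fused SDPA
    if use_varlen:
        return "varlen"  # packed sequences of different lengths
    return "flash"       # dense causal attention, no mask

assert choose_backend(True, False) == "sdpa"
assert choose_backend(False, True) == "varlen"
assert choose_backend(False, False) == "flash"
```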

Step 6: Output projection

attn_output = A.reshape(bsz, q_len, n_heads * head_dim)
attn_output = self.apply_o(self, attn_output)

The attention output is reshaped and projected through the output linear layer. apply_o may use Unsloth’s fast_linear_forward() under the hood.


Lesson 7: Trace the decode path — Qwen3Attention_fast_forward_inference

This is the separate inference function for single-token decode — the step that runs thousands of times during generation. It is the most performance-critical code in the entire file.

Find def Qwen3Attention_fast_forward_inference(...) in the source.

Why a separate decode path?

During generation, the model processes tokens one at a time. The first pass (prefill) processes the entire prompt at once — many tokens, standard batched computation. Every subsequent step (decode) processes exactly one new token, looking up all previous tokens in the KV cache.

flowchart LR
    subgraph prefill ["Prefill: process full prompt"]
        P1["512 tokens at once"]
        P2["Standard batched attention"]
        P3["Fill KV cache"]
    end
    subgraph decode ["Decode: one token at a time"]
        D1["1 token"]
        D2["Look up full KV cache"]
        D3["Repeat 100-1000x"]
    end
    prefill --> decode

Because the decode step is repeated thousands of times with the same tiny shape (1 token), it is worth optimizing separately from the general path. Every microsecond saved in one decode step gets multiplied by every generated token.

The docstring — Daniel’s own explanation

The function has an unusually detailed docstring. Read it carefully:

"""
QK^T can be computed in 4 chunks

[Q, q] @ [K, k].T where q, k are the new tokens.
[QK^T, Qk^T]
[qK^T, qk^T]

Since the attention mask wipes Qk^T, we just get
[QK^T,    0]
[qK^T, qk^T]

Since softmax is row-wise, we get
softmax([QK^T,    0])
softmax([qK^T, qk^T])

We then multiply by   [V]
                       [v]
softmax([QK^T,    0]) [softmax(QK^T)V] *
softmax([qK^T, qk^T]) [softmax([qK^T, qk^T]) @ [V, v]]

But notice * [softmax(QK^T)V] is just the last attention.
We just need to compute the last final row.
"""

This is the mathematical justification for the optimization: when generating one token at a time, you only need to compute the last row of the attention matrix. All previous rows were computed in earlier steps and their results are already in the KV cache.
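The docstring's claim is easy to verify numerically. The sketch below (pure Python, tiny made-up matrices) computes the full causal attention output and checks that its last row matches a decode-style computation that only touches the new query and the cached K/V:

```python
import math

def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    s = sum(exps)
    return [e / s for e in exps]

def attend_one(q, K, V):
    # decode path: one query row against the cached keys/values
    w = softmax([sum(a * b for a, b in zip(q, k)) for k in K])
    return [sum(wi * v[d] for wi, v in zip(w, V)) for d in range(len(V[0]))]

def full_causal(Q, K, V):
    # prefill path: the whole matrix, with a causal mask above the diagonal
    out = []
    for i, q in enumerate(Q):
        scores = [sum(a * b for a, b in zip(q, k)) if j <= i else float("-inf")
                  for j, k in enumerate(K)]
        w = softmax(scores)
        out.append([sum(wi * v[d] for wi, v in zip(w, V)) for d in range(len(V[0]))])
    return out

Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 2.0], [0.5, 0.5], [2.0, 1.0]]
V = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

full = full_causal(Q, K, V)
last = attend_one(Q[2], K, V)  # only the final row, using the "KV cache"
assert all(abs(a - b) < 1e-9 for a, b in zip(full[2], last))
```

The earlier rows of the full matrix were already produced during earlier decode steps, so only the last row is new work each step.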

Pre-allocated buffers

if do_prefill:
    self.paged_attention = torch.empty(
        (KV_CACHE_INCREMENT + seq_len + 1, 2, bsz, n_kv_heads, head_dim),
        dtype=dtype, device=device,
    )
    self.paged_attention_K = self.paged_attention[:, 0]
    self.paged_attention_V = self.paged_attention[:, 1]
    self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)
    self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)
    self.temp_QA = torch.empty(
        (2, bsz, 1, attention_size), dtype=dtype, device=device
    )
    self.temp_KV = torch.empty(
        (2, bsz, 1, n_kv_heads * head_dim), dtype=dtype, device=device
    )
    self.RH_Q = torch.empty(
        (bsz, n_heads, 1, head_dim), dtype=dtype, device=device
    )
    self.attention = torch.empty(
        (bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len),
        dtype=dtype, device=device
    )
    self.scalar = 1.0 / math_sqrt(self.head_dim)
Tip: What Daniel is solving here

In the default HuggingFace path, temporary tensors for Q, K, V, and attention scores are allocated every single decode step. That is thousands of torch.empty() calls during one generation. Daniel pre-allocates all buffers once during prefill and reuses them for every subsequent decode step. Zero allocation overhead in the hot loop.

Notice self.paged_attention — one large tensor that holds both K and V ([:, 0] is K, [:, 1] is V). When the cache runs out of space, it grows by KV_CACHE_INCREMENT slots:

elif kv_seq_len >= self.paged_attention.shape[0]:
    self.paged_attention.resize_((
        self.paged_attention.shape[0] + KV_CACHE_INCREMENT,
        2, bsz, n_kv_heads, head_dim,
    ))

Compare this to the KV cache in Step 2’s C code, where you pre-allocated fixed arrays. Same idea, but Daniel’s version grows dynamically and is optimized for GPU memory layout.
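The grow-in-chunks policy can be sketched in pure Python. KV_CACHE_INCREMENT's real value is defined in Unsloth; 256 below is a stand-in, and the class only tracks capacity rather than holding real tensors:

```python
KV_CACHE_INCREMENT = 256  # illustrative stand-in

class GrowableCache:
    def __init__(self, prompt_len):
        # allocated once at prefill, with headroom for future tokens
        self.capacity = KV_CACHE_INCREMENT + prompt_len + 1
        self.length = prompt_len

    def append_token(self):
        if self.length >= self.capacity:
            # mirrors paged_attention.resize_: grow by one chunk, not per token
            self.capacity += KV_CACHE_INCREMENT
        self.length += 1

cache = GrowableCache(prompt_len=10)
start = cache.capacity  # 267
grows = 0
for _ in range(600):
    before = cache.capacity
    cache.append_token()
    grows += cache.capacity != before

print(start, cache.capacity, grows)  # 267 779 2
```

Six hundred generated tokens trigger only two growth events instead of six hundred allocations.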

In-place RoPE (no allocation)

RH_Q = self.RH_Q
RH_Q[:, :, :, :h] = Qn[:, :, :, h:]
RH_Q[:, :, :, h:] = Qn[:, :, :, :h]
RH_Q[:, :, :, :h].neg_()
Qn *= cos
Qn.addcmul_(RH_Q, sin)

This is RoPE applied in-place — no new tensors are created. The rotation is done by:

  1. Splitting Q into two halves
  2. Swapping them (the rotation)
  3. Negating the first half
  4. Multiplying by cos and adding the rotated version multiplied by sin

In Step 2 you saw this as a simple loop over cos/sin values in C. Here it is the same math, but done entirely in pre-allocated buffers with no memory allocation.
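The four steps can be checked with a tiny pure-Python version of the rotate-half rotation (lists instead of pre-allocated GPU buffers, so this shows the math rather than the allocation trick):

```python
import math

def rope_rotate_half(q, cos, sin):
    # build the rotated copy [-q2, q1], then q*cos + rotated*sin
    h = len(q) // 2
    rotated = [-v for v in q[h:]] + list(q[:h])
    return [qi * c + ri * s for qi, ri, c, s in zip(q, rotated, cos, sin)]

theta = math.pi / 2
out = rope_rotate_half([1.0, 0.0], [math.cos(theta)] * 2, [math.sin(theta)] * 2)
# rotating (1, 0) by 90 degrees gives approximately (0, 1)
```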

The decode attention itself

if bsz == 1:
    Qn *= self.scalar
    A = torch_matmul(
        Qn, Knn.transpose(2, 3),
        out=self.attention[:, :, :, :cached_len]
    )
    A[:] = torch_nn_functional_softmax(A, dim=-1, dtype=torch.float32)
    A = torch_matmul(A, Vnn, out=Qn)

For batch size 1 (the most common case in interactive generation), Daniel manually implements attention as two matrix multiplies with softmax in between. The out= parameter writes results into pre-allocated buffers — no allocation.

Notice Qn *= self.scalar — the scaling happens before the matmul, not after. This avoids overflow in FP16/BF16. The comment in the code references a llama.cpp issue about this exact problem.

For larger batch sizes, it falls back to PyTorch’s scaled_dot_product_attention:

else:
    if use_sdpa_gqa:
        A = scaled_dot_product_attention(
            Qn, Knn, Vnn,
            attn_mask=attention_mask,
            is_causal=is_causal,
            enable_gqa=True,
        )
Tip: What Daniel is solving here

The decode path is where Daniel’s optimization compounds the most. Each of the ~100-1000 decode steps runs this function. Pre-allocated buffers + in-place math + manual scaling = the hot loop runs with zero allocation overhead and minimal kernel launches.


Lesson 8: Follow one call down — the RMSNorm Triton kernel

We have seen that Unsloth calls fast_rms_layernorm() instead of the stock RMSNorm. But where does the actual speedup come from? To answer that, we need to follow the call one level deeper — into a Triton kernel.

Navigate to unsloth/kernels/rms_layernorm.py in the Unsloth repo.

What is Triton?

Triton is a language for writing GPU kernels in Python. Instead of writing raw CUDA C, you write Python-like code that Triton compiles into efficient GPU machine code. The @triton.jit decorator marks a function as a GPU kernel.

What “fused” means concretely

A standard PyTorch RMSNorm executes as a sequence of separate operations (counting sum and mean as a single step, one fewer than the six-launch diagram earlier):

  1. x_squared = x * x — kernel launch 1
  2. mean_sq = x_squared.mean() — kernel launch 2
  3. rms = torch.sqrt(mean_sq + eps) — kernel launch 3
  4. x_norm = x / rms — kernel launch 4
  5. output = x_norm * weight — kernel launch 5

Each kernel launch reads data from GPU global memory, computes, and writes back. The data travels between GPU registers and global memory five times.

A fused kernel does all five steps in one kernel body. Data is loaded once, all computations happen in GPU registers/shared memory, and the result is written once. One read, one write.

flowchart TD
    subgraph unfused ["Unfused: 5 separate kernel launches"]
        direction TB
        U1["load x from memory"] --> U2["square → write to memory"]
        U2 --> U3["load → mean → write"]
        U3 --> U4["load → sqrt → write"]
        U4 --> U5["load → divide → write"]
        U5 --> U6["load → multiply weight → write"]
    end
    subgraph fused_kern ["Fused: 1 kernel launch"]
        direction TB
        F1["load x from memory once"]
        F1 --> F2["square, mean, sqrt, divide,\nmultiply weight\nall in registers"]
        F2 --> F3["write result once"]
    end
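One way to convince yourself that fusing changes only the memory traffic, not the result, is to write both pipelines in plain Python (lists standing in for tensors written to global memory) and check that they agree:

```python
import math

def rmsnorm_unfused(x, w, eps=1e-5):
    # each step materializes an intermediate list -- the analogue of a
    # kernel writing its result back to global memory
    squared = [v * v for v in x]                  # launch 1
    mean_sq = sum(squared) / len(x)               # launch 2
    rms = math.sqrt(mean_sq + eps)                # launch 3
    normed = [v / rms for v in x]                 # launch 4
    return [n * wi for n, wi in zip(normed, w)]   # launch 5

def rmsnorm_fused(x, w, eps=1e-5):
    # one pass: compute the scale once, then a single loop reads x
    # and writes the output, with no intermediates
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv_rms * wi for v, wi in zip(x, w)]

x, w = [3.0, 4.0, 0.0, -1.0], [1.0, 0.5, 2.0, 1.0]
a, b = rmsnorm_unfused(x, w), rmsnorm_fused(x, w)
assert all(abs(p - q) < 1e-12 for p, q in zip(a, b))
```

Same numbers out; the fused version simply never parks intermediates in memory.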

Why this compounds

RMSNorm appears 56 times per token in a 28-layer model (once before attention, once before FFN, in each layer). If fusing saves 0.01ms per call, that is 0.56ms per token. Over 1000 generated tokens, that is 560ms saved — just from one optimization.

And RMSNorm is only one of several fused kernels. The same approach applies to RoPE, SwiGLU, cross-entropy, and fast linear operations.

The kernel registry

Unsloth exposes a full set of fast operations. These are the functions that Qwen3Attention_fast_forward and the other patched methods call into:

| Function | What it replaces | Where it appears |
|---|---|---|
| fast_rms_layernorm | torch.nn.RMSNorm | Every decoder block (x2) |
| fast_rope_embedding | Stock RoPE | Inside attention |
| fast_linear_forward | nn.Linear forward | Q/K/V projections, output projection |
| SwiGLU/GEGLU kernels | Gated MLP activation | FFN in every block |
| Cross-entropy kernel | F.cross_entropy | Training loss |
| Fast LoRA paths | Standard LoRA forward | Fine-tuning |

Lesson 9: The full stack — from from_pretrained() to Triton

Let us put the entire call chain together. This is the complete path from user code to GPU kernel:

flowchart TD
    U["User: FastQwen3Model.from_pretrained('Qwen/Qwen3-0.6B')"] --> FP["from_pretrained(model_patcher=FastQwen3Model)"]
    FP --> PP["pre_patch() runs"]
    PP --> PATCH["All .forward() methods overwritten"]
    PATCH --> GEN["User: model.generate('Hello')"]
    GEN --> CLM["CausalLM_fast_forward()"]
    CLM --> MOD["LlamaModel_fast_forward()"]
    MOD --> DEC["LlamaDecoderLayer_fast_forward()\n x28 layers"]
    DEC --> ATT["Qwen3Attention_fast_forward()"]
    ATT --> RMS["fast_rms_layernorm()"]
    RMS --> TRITON["Triton kernel\n1 GPU kernel launch"]
    ATT --> ROPE["fast_rope_embedding()"]
    ROPE --> TRITON2["Triton kernel"]
    ATT --> ATTN["run_attention()\nbackend dispatch"]
    DEC --> FFN["FFN with SwiGLU\nfused kernel"]

Every box below pre_patch() is Unsloth’s code. The user only touches the top two boxes. Everything else is transparent.

This is the power of monkey patching as a system design: the user’s code does not change, but the entire execution stack is different.


Lesson 10: Trade-offs Daniel lives with

Monkey patching is powerful, but it creates real engineering costs. Understanding these trade-offs is as important as understanding the optimization.

Fragility across library versions

Monkey patches depend on exact class and method structure. If HuggingFace renames a class, moves a method, or changes a function signature, the patch breaks.

This is why the file has version checks:

if not transformers_version >= Version("4.50.3"):
    raise ImportError(
        "Unsloth: Your transformers version does not support Qwen3..."
    )

Every HuggingFace release is a potential breaking change. Unsloth must test against each new version and update patches accordingly. This is ongoing maintenance cost.

Warm-up cost

Triton kernels need to be compiled on first use. The first inference or training step is slower while kernels are compiled and cached. Subsequent runs reuse the compiled cache — but the first-run penalty is real.

Static-shape assumptions

Some fused kernels only work when shapes are stable. If the head dimension, sequence length growth pattern, or dtype changes mid-run, it can trigger recompilation or fallback to slower paths. The decode path in particular assumes shapes are predictable — that is why it pre-allocates fixed buffers.

Harder debugging

Since execution no longer follows the stock HuggingFace code path, stack traces point to Unsloth functions, not the familiar HF code. When something goes wrong, you have to know about the monkey patch to understand what is actually running. Breakpoints in HF code will never be hit.

Tip: The honest engineering take

Monkey patching is not free. It trades simplicity and debuggability for performance. Daniel accepts this trade-off because the performance gains compound across every token in every generation. But it requires discipline: version testing, careful shape management, and clear error messages when things break.


Lesson 11: Connections to Step 2

If you completed Step 2: Run a Model Locally, you already understand every concept that Unsloth optimizes. The difference is the layer of abstraction.

| Concept from Step 2 (C code) | Where in Unsloth | What changed |
|---|---|---|
| rmsnorm() — sequential C math | fast_rms_layernorm() — Triton kernel | Fused into one GPU kernel vs. sequential C loop |
| RoPE rotation loop with cos/sin | fast_rope_embedding() | Fused rotation, no per-element kernel launches |
| attention() — Q @ K, softmax, @ V | run_attention() with backend dispatch | Selects SDPA/Flash/VarLen based on context |
| key_cache, value_cache arrays | self.paged_attention tensor | Paged allocation, pre-allocated temp buffers |
| FFN with SiLU gating | SwiGLU fused kernel | Gate + activation + multiply in one kernel |
| forward() function | Monkey-patched .forward() methods | Same call, rerouted through optimized paths |

Step 2 showed you the raw math in C — what each operation actually computes. This post showed you how production systems optimize that same math for GPU throughput. You need both: understanding the math tells you what is being optimized; understanding Unsloth tells you how and why.


Exercises: code reading practice

These exercises require you to open real source files and trace code. This is the skill you are building.

Exercise 1: Count the patches

Open the Unsloth Qwen3 source file. Find pre_patch(). List every .forward = assignment. How many classes are patched? Which one is NOT a .forward replacement?

Exercise 2: Trace the normalization

In Qwen3Attention_fast_forward, find the two lines where Q and K are normalized. What function is called? Now search the Unsloth codebase for that function’s definition. What file is it in? Does it call a Triton kernel?

Exercise 3: Compare HF to Unsloth

Open HuggingFace’s modeling_qwen3.py. Find Qwen3Attention.forward(). Now open Unsloth’s Qwen3Attention_fast_forward(). Find three concrete differences. Which ones affect performance? Which affect correctness?

Exercise 4: Read the Triton kernel

Navigate to unsloth/kernels/rms_layernorm.py. Find the @triton.jit decorator. Read the kernel body. Can you identify the “fuse boundary” — where multiple math operations that PyTorch would run as separate kernels are combined into one?

Exercise 5: Write a diagnostic monkey patch

Write a Python script that:

  1. Loads Qwen/Qwen3-0.6B using stock HuggingFace AutoModelForCausalLM
  2. Saves a reference to the original forward: original_forward = type(model.model.layers[0].self_attn).forward
  3. Writes a wrapper that prints "LAYER 0 ATTENTION CALLED" before calling original_forward
  4. Patches it: type(model.model.layers[0].self_attn).forward = your_wrapper
  5. Runs a forward pass and verifies the print appears

This proves you understand the mechanism. You do not need a GPU or Triton — this works on CPU.


One-paragraph summary

Unsloth uses monkey patching to overwrite HuggingFace model methods at runtime, so the standard .forward() path is redirected into Unsloth-controlled fast paths. Those fast paths call optimized implementations for hot operations — RMSNorm, RoPE, gated MLPs, cross-entropy, and custom decode/KV-cache logic — many backed by Triton GPU kernels that fuse multiple operations into single kernel launches. The inference decode path goes further: pre-allocated buffers, in-place RoPE, paged KV cache, and manual attention computation eliminate all dynamic allocation from the hot loop. The user still calls from_pretrained() and generate(), but the runtime is completely different.