Long Context gpt-oss Training

We’re excited to introduce Unsloth Flex Attention support for OpenAI gpt-oss training, which enables >8× longer context lengths, >50% less VRAM usage and >1.5× faster training than all other implementations, including those using Flash Attention 3 (FA3). Unsloth Flex Attention makes it possible to train with a 60K context length on an 80GB VRAM H100 GPU for BF16 LoRA. Also:

  • You can now export/save your QLoRA fine-tuned gpt-oss model to llama.cpp, vLLM, Ollama or HF

  • We fixed gpt-oss training losses going to infinity on float16 GPUs (like T4 Colab)

  • We fixed gpt-oss implementation issues unrelated to Unsloth, most notably ensuring that swiglu_limit = 7.0 is properly applied during MXFP4 inference in transformers

🦥Introducing Unsloth Flex Attention Support

With Unsloth's Flex Attention support, a single 80GB VRAM H100 can handle up to 81K context length with QLoRA and 60K context with BF16 LoRA! These gains apply to BOTH gpt-oss-20b and gpt-oss-120b! The longer the context, the larger the gains from Unsloth Flex Attention.

In comparison, all other non-Unsloth implementations max out at 9K context length on an 80GB GPU, and can only reach 15K context with FA3. However, FA3 is unsuitable for gpt-oss training since it lacks backward pass support for attention sinks, so if you were previously using FA3 for gpt-oss training, we recommend not using it for now. The maximum context length you can get without Unsloth on 80GB VRAM is therefore ~9K.

Training with Unsloth Flex Attention delivers at least a 1.3× speedup, with gains growing as context length increases, reaching up to 2× faster. Because Flex Attention scales with context, longer sequences yield bigger savings in both VRAM and training time, as described here.

A huge thank you to Rohan Pandey for his Flex Attention implementation, which directly inspired the development of Unsloth's Flex Attention implementation.

🕶️ Attention Sinks

OpenAI's gpt-oss models alternate between sliding window attention and full attention layers (SWA, FA, SWA, FA, etc.). Each sliding window layer only attends to 128 tokens (including the current token), so computation is vastly reduced. However, such a small window on its own makes long context retrieval and reasoning ineffective. Most labs fix this by expanding the sliding window to 2048 or 4096 tokens.
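To make the saving concrete, here is a small sketch that counts how many (query, key) pairs a causal layer has to score with and without a 128-token window. The function names and the 8,192-token example length are illustrative assumptions, not values from any gpt-oss code.

```python
def full_causal_pairs(n_tokens: int) -> int:
    """Number of (query, key) pairs scored by full causal attention."""
    return n_tokens * (n_tokens + 1) // 2


def sliding_window_pairs(n_tokens: int, window: int = 128) -> int:
    """Pairs scored when each query sees at most `window` tokens, itself included."""
    return sum(min(q_idx + 1, window) for q_idx in range(n_tokens))


n = 8192  # illustrative sequence length (an assumption for this example)
full = full_causal_pairs(n)
windowed = sliding_window_pairs(n)
print(f"full: {full:,}  sliding window: {windowed:,}  ratio: {full / windowed:.1f}x")
```

The gap grows with sequence length, because the full-attention count is quadratic while the windowed count is roughly linear.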

OpenAI instead leveraged attention sinks from the Efficient Streaming Language Models with Attention Sinks paper, which shows that a small sliding window works as long as you add global attention on the first token. The paper provides a good illustration below:

The paper finds that the attention mechanism assigns a lot of weight to the first few tokens (1 to 4). When the sliding window moves past them, these "important" first few tokens disappear, which causes bad long context retrieval.

If we plot log perplexity (higher is worse) and run inference beyond the pretrained model's context length, the perplexity shoots up (not good). However, the red line (which uses Attention Sinks) stays low, which is very good!

The paper also shows that the Attention Is Off By One method partially works, except one must also add a few extra sink tokens to get lower perplexities. It further shows that adding a single learnable sink token does remarkably well, and that's what OpenAI did for gpt-oss!

📐Unsloth's Flex Attention implementation

Flex Attention (https://pytorch.org/blog/flexattention/) is extremely powerful, as it gives the practitioner two customization routes for the attention mechanism: a score modifier (f) and a masking function (M).

The score modifier (f) lets us edit the attention logits before the softmax operation, and the masking function (M) lets us skip operations we don't need (e.g. sliding window attention only sees the last 128 tokens).
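As a rough mental model, the two hooks can be emulated densely in plain PyTorch. The sketch below is not how Flex Attention is implemented (it materializes the full score matrix and skips nothing), and `dense_attention` and `distance_penalty` are hypothetical names made up for this example. The callback signatures mirror the `(b, h, q_idx, kv_idx)` style used by the mask functions later on this page.

```python
import torch


def dense_attention(q, k, v, score_mod=None, mask_mod=None):
    """Dense reference: softmax over scaled, modified and masked logits."""
    n_batch, n_heads, q_len, d_head = q.shape
    kv_len = k.shape[-2]
    scores = (q @ k.transpose(-1, -2)) / d_head**0.5

    # Index tensors shaped so that they broadcast to (batch, heads, q_len, kv_len).
    b = torch.arange(n_batch).view(-1, 1, 1, 1)
    h = torch.arange(n_heads).view(1, -1, 1, 1)
    q_idx = torch.arange(q_len).view(1, 1, -1, 1)
    kv_idx = torch.arange(kv_len).view(1, 1, 1, -1)

    if score_mod is not None:
        scores = score_mod(scores, b, h, q_idx, kv_idx)
    if mask_mod is not None:
        keep = mask_mod(b, h, q_idx, kv_idx)  # True where attention is allowed
        scores = scores.masked_fill(~keep, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v


def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


def distance_penalty(scores, b, h, q_idx, kv_idx):
    # Hypothetical score modifier: penalize keys that are far from the query.
    return scores - 0.1 * (q_idx - kv_idx).abs()


torch.manual_seed(0)
q, k, v = (torch.randn(1, 2, 6, 8) for _ in range(3))
out = dense_attention(q, k, v, score_mod=distance_penalty, mask_mod=causal)
print(out.shape)
```

With a causal mask, the first query can only see the first key, so its output row is exactly the first value row, which is a handy sanity check.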

The trick is that Flex Attention provides fast, auto-generated Triton kernels for arbitrary score modifiers and masking functions!

$$\sigma\bigg(s \times \mathbf{f}\big(QK^T + \mathbf{M}\big)\bigg)$$

This means we can use Flex Attention to implement attention sinks! An implementation of a single attention sink is provided in both OpenAI's original gpt-oss repo and Hugging Face's transformers implementation:

combined_logits = torch.cat([attn_weights, sinks], dim=-1)
probs = F.softmax(combined_logits, dim=-1)
scores = probs[..., :-1]

The above shows that we concatenate the sink at the very end of Q @ K.T, do the softmax, and remove the last column, which was the sink token.
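A runnable version of that three-line snippet, using random tensors with made-up shapes, shows what the sink does numerically: it only adds an extra exp(sink) term to the softmax denominator, so the remaining attention weights sum to less than 1. This is a minimal sketch, and the tensor sizes are arbitrary assumptions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
attn_weights = torch.randn(2, 4, 5, 5)  # (batch, heads, q_len, kv_len) logits
# One sink logit per head, broadcast over batch and query positions.
sinks = torch.randn(4).reshape(1, -1, 1, 1).expand(2, -1, 5, -1)

# Concatenate, softmax, then drop the sink column (same steps as above).
combined_logits = torch.cat([attn_weights, sinks], dim=-1)
probs = F.softmax(combined_logits, dim=-1)
scores = probs[..., :-1]

# Equivalent closed form: the sink only contributes exp(sink) to the denominator.
denominator = attn_weights.exp().sum(dim=-1, keepdim=True) + sinks.exp()
closed_form = attn_weights.exp() / denominator

print(torch.allclose(scores, closed_form, atol=1e-6))
print(scores.sum(dim=-1).max().item())  # strictly below 1: the sink absorbed the rest
```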

By using some visualization utilities from Flex Attention's GitHub repo, we can visualize this. Assume a sequence length of 16 and a sliding window of 5. On the left is the sink as the last column (default implementation), and on the right is the sink moved to index 0 (our implementation).

Sink location at the end (default)

Move sink location to index 0
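The two sink placements are interchangeable as far as the result is concerned, because softmax does not care about column order. The following sketch checks this on random data (sizes are arbitrary and no mask is applied, to keep the example short).

```python
import torch

torch.manual_seed(0)
logits = torch.randn(3, 16, 16)  # (heads, q_len, kv_len), illustrative sizes
sink = torch.randn(3, 1, 1).expand(-1, 16, -1)  # one sink logit per head
values = torch.randn(3, 16, 8)

# Default layout: sink is the last column and is dropped after the softmax.
probs_end = torch.softmax(torch.cat([logits, sink], dim=-1), dim=-1)[..., :-1]

# Alternative layout: sink is column 0 and is dropped after the softmax.
probs_first = torch.softmax(torch.cat([sink, logits], dim=-1), dim=-1)[..., 1:]

out_end = probs_end @ values
out_first = probs_first @ values
print(torch.allclose(out_end, out_first, atol=1e-6))
```

The only practical consequence of moving the sink to index 0 is that every real key shifts one column to the right, which any mask function has to account for.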

Interesting finding: the official Flex Attention sliding window implementation treats the window size as the number of previous tokens PLUS ONE, since it also includes the current token. The Hugging Face and gpt-oss implementations strictly see only the last N tokens. The code below is from https://pytorch.org/blog/flexattention/ and https://github.com/meta-pytorch/attention-gym:

def sliding_window_causal(b, h, q_idx, kv_idx):
    causal_mask = q_idx >= kv_idx
    window_mask = q_idx - kv_idx <= SLIDING_WINDOW 
    return causal_mask & window_mask

Default Flex Attention (3+1 tokens)

HuggingFace, GPT-OSS (3+0 tokens)

We also checked OpenAI's official gpt-oss implementation to confirm whether it attends to the last N or N+1 tokens: https://github.com/openai/gpt-oss/blob/main/gpt_oss/torch/model.py

mask = torch.triu(Q.new_full((n_tokens, n_tokens), -float("inf")), diagonal=1)
if sliding_window > 0:
    mask += torch.tril(
        mask.new_full((n_tokens, n_tokens), -float("inf")), diagonal=-sliding_window
    )

And we see only the last 3 tokens (not 3+1) are attended to! This means that instead of <= SLIDING_WINDOW, we should use < SLIDING_WINDOW (i.e. strictly less than, not less than or equal).

def sliding_window_causal(b, h, q_idx, kv_idx):
    causal_mask = q_idx >= kv_idx
    # Default Flex Attention (attends to SLIDING_WINDOW + 1 tokens):
    # window_mask = q_idx - kv_idx <= SLIDING_WINDOW
    # GPT-OSS version (attends to exactly SLIDING_WINDOW tokens):
    window_mask = q_idx - kv_idx < SLIDING_WINDOW
    return causal_mask & window_mask
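The off-by-one can be checked directly by materializing both variants as dense boolean masks and comparing them with a mask built the same way as the torch.triu / torch.tril code quoted earlier. The helper names and the tiny sizes below are assumptions chosen for illustration.

```python
import torch

SLIDING_WINDOW = 3  # small illustrative window
N_TOKENS = 8


def openai_style_mask(n_tokens, sliding_window):
    """Boolean 'may attend' mask built like the torch.triu/torch.tril code above."""
    mask = torch.triu(torch.full((n_tokens, n_tokens), -float("inf")), diagonal=1)
    if sliding_window > 0:
        mask += torch.tril(
            torch.full((n_tokens, n_tokens), -float("inf")), diagonal=-sliding_window
        )
    return mask == 0  # entries left at 0 are the ones that may be attended to


def mask_from_fn(inclusive):
    """Dense version of sliding_window_causal with either <= or < as the comparison."""
    q_idx = torch.arange(N_TOKENS).unsqueeze(-1)
    kv_idx = torch.arange(N_TOKENS).unsqueeze(0)
    causal_mask = q_idx >= kv_idx
    distance = q_idx - kv_idx
    window_mask = distance <= SLIDING_WINDOW if inclusive else distance < SLIDING_WINDOW
    return causal_mask & window_mask


reference = openai_style_mask(N_TOKENS, SLIDING_WINDOW)
strict = mask_from_fn(inclusive=False)
inclusive = mask_from_fn(inclusive=True)

print(torch.equal(reference, strict))     # strict "<" reproduces the reference mask
print(torch.equal(reference, inclusive))  # "<=" lets one extra token through
print(reference.sum(dim=-1).tolist())     # keys visible to each query
```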

Also, since we moved the sink token to the first index, we have to add 1 to q_idx to index correctly:

def causal_mask_with_sink(batch, head, q_idx, kv_idx):
    """
      0 1 2 3     0 1 2 3
    0 X X       1   X
    1 X X X     2   X X
    2 X X X X   3   X X X
    """
    # We add (q_idx + 1) since first column is sink token
    causal_mask = (q_idx + 1) >= kv_idx
    sink_first_column = kv_idx == 0
    return causal_mask | sink_first_column
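One way to convince yourself that the shifted indexing is right is to materialize the mask for a tiny sequence and compare the attention output against a sink-last reference that uses an ordinary causal mask. This is a dense sketch with arbitrary sizes and a single made-up sink logit; it is not the Unsloth kernel.

```python
import torch


def causal_mask_with_sink(batch, head, q_idx, kv_idx):
    # Column 0 is the sink, so token j lives in column j + 1.
    causal_mask = (q_idx + 1) >= kv_idx
    sink_first_column = kv_idx == 0
    return causal_mask | sink_first_column


n_tokens, d_head = 4, 8
q_idx = torch.arange(n_tokens).unsqueeze(-1)
kv_idx = torch.arange(n_tokens + 1).unsqueeze(0)  # one extra column for the sink
keep = causal_mask_with_sink(None, None, q_idx, kv_idx)
print(keep.int())

torch.manual_seed(0)
q, k, v = (torch.randn(n_tokens, d_head) for _ in range(3))
sink_logit = torch.randn(())  # a single illustrative sink logit
qk = q @ k.T / d_head**0.5

# Sink-first layout: prepend the sink logit, mask, softmax, drop column 0.
logits = torch.cat([sink_logit.expand(n_tokens, 1), qk], dim=-1)
probs = torch.softmax(logits.masked_fill(~keep, float("-inf")), dim=-1)
out_first = probs[:, 1:] @ v

# Sink-last reference with an ordinary lower-triangular causal mask.
causal = torch.tril(torch.ones(n_tokens, n_tokens, dtype=torch.bool))
ref_logits = torch.cat(
    [qk.masked_fill(~causal, float("-inf")), sink_logit.expand(n_tokens, 1)], dim=-1
)
out_last = torch.softmax(ref_logits, dim=-1)[:, :-1] @ v

print(torch.allclose(out_first, out_last, atol=1e-6))
```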

To confirm our index 0 implementation, we verified that the training loss remains consistent with standard Hugging Face runs (without Unsloth Flex Attention), as shown in our graph:

💾NEW: Saving to GGUF, vLLM after gpt-oss training

You can now QLoRA fine-tune gpt-oss and directly save, export, or merge the model to llama.cpp, vLLM, or HF, not just Unsloth. We hope to release a free notebook soon.

Previously, any QLoRA fine-tuned gpt-oss model was restricted to running in Unsloth. We’ve removed that limitation by introducing on-demand dequantization of MXFP4 base models (like gpt-oss) during the LoRA merge process. This makes it possible to export your fine-tuned model in bf16 format.
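Conceptually, merging means dequantizing the frozen base weight and folding the low-rank LoRA update into it. The sketch below illustrates only that idea: the per-row int8 scheme is a toy stand-in and is NOT MXFP4, `dequantize` is a hypothetical helper, and the sizes, rank and alpha are made up. It does not show Unsloth's actual merge code.

```python
import torch

torch.manual_seed(0)
out_features, in_features, rank, alpha = 16, 32, 4, 8

# Toy stand-in for a packed base weight: int8 codes plus one scale per row.
# This is NOT MXFP4; it only plays the role of "some quantized format".
w_full = torch.randn(out_features, in_features)
scale = w_full.abs().amax(dim=1, keepdim=True) / 127.0
codes = torch.round(w_full / scale).to(torch.int8)


def dequantize(codes, scale):
    """Hypothetical helper: unpack the toy format to bfloat16."""
    return (codes.float() * scale).to(torch.bfloat16)


# LoRA factors, as they would exist after training on the frozen base.
lora_a = torch.randn(rank, in_features) * 0.01
lora_b = torch.randn(out_features, rank) * 0.01

# Merge: dequantize only at merge time, then add the low-rank update.
w_base = dequantize(codes, scale).float()
w_merged = (w_base + (alpha / rank) * (lora_b @ lora_a)).to(torch.bfloat16)

# The merged layer should match "base layer + adapter" up to bf16 rounding.
x = torch.randn(3, in_features)
unmerged = x @ w_base.T + (alpha / rank) * (x @ lora_a.T @ lora_b.T)
merged = x @ w_merged.float().T
print((unmerged - merged).abs().max().item())
```

The small residual printed at the end comes from rounding the merged weight back to bfloat16.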

After fine-tuning your gpt-oss model, you can now merge it into a 16-bit format with a single command:

model.save_pretrained_merged(save_directory, tokenizer)

If you prefer to merge the model and push it directly to the Hugging Face Hub instead, you can do so using:

model.push_to_hub_merged(repo_name, tokenizer=tokenizer, token=hf_token)

To run inference on the merged model, you can use vLLM or llama.cpp, among others.

⚙️ Inference Settings

OpenAI recommends these inference settings for both models: temperature=1.0, top_p=1.0, top_k=0
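In other words, the recommended settings sample from the model's unmodified distribution. The sketch below illustrates this with a generic sampler, assuming the common convention that top_k=0 and top_p=1.0 both mean "no filtering"; `filter_logits` is a hypothetical helper written for this example, not an API of any inference engine.

```python
import torch


def filter_logits(logits, temperature=1.0, top_p=1.0, top_k=0):
    """Apply temperature, top-k and top-p filtering, then return probabilities.

    Assumed convention: top_k=0 and top_p=1.0 disable their respective filters.
    """
    logits = logits / temperature
    if top_k > 0:
        kth_largest = torch.topk(logits, top_k).values[..., -1, None]
        logits = logits.masked_fill(logits < kth_largest, float("-inf"))
    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        sorted_probs = torch.softmax(sorted_logits, dim=-1)
        # Drop a token once the probability mass *before* it already exceeds top_p.
        mass_before = sorted_probs.cumsum(dim=-1) - sorted_probs
        sorted_logits = sorted_logits.masked_fill(mass_before > top_p, float("-inf"))
        logits = torch.full_like(logits, float("-inf")).scatter(-1, sorted_idx, sorted_logits)
    return torch.softmax(logits, dim=-1)


logits = torch.tensor([2.0, 1.0, 0.5, -1.0])  # made-up next-token logits
recommended = filter_logits(logits, temperature=1.0, top_p=1.0, top_k=0)
restricted = filter_logits(logits, top_k=2)

print(torch.allclose(recommended, torch.softmax(logits, dim=-1)))
print(recommended)
print(restricted)
```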

Llama.cpp
  1. Obtain the latest llama.cpp on GitHub here. You can follow the build instructions below as well. Change -DGGML_CUDA=ON to -DGGML_CUDA=OFF if you don't have a GPU or just want CPU inference.

    apt-get update
    apt-get install pciutils build-essential cmake curl libcurl4-openssl-dev -y
    git clone https://github.com/ggml-org/llama.cpp
    cmake llama.cpp -B llama.cpp/build \
        -DBUILD_SHARED_LIBS=OFF -DGGML_CUDA=ON -DLLAMA_CURL=ON
    cmake --build llama.cpp/build --config Release -j --clean-first --target llama-cli llama-gguf-split llama-quantize
    cp llama.cpp/build/bin/llama-* llama.cpp
  2. Convert and quantize the merged model:

    python3 llama.cpp/convert_hf_to_gguf.py gpt-oss-finetuned-merged/ --outfile gpt-oss-finetuned.gguf
    llama.cpp/llama-quantize gpt-oss-finetuned.gguf  gpt-oss-finetuned-Q8_0.gguf Q8_0
  3. Run inference on the quantized model:

    llama.cpp/llama-cli --model gpt-oss-finetuned-Q8_0.gguf \
        --jinja -ngl 99 --threads -1 --ctx-size 16384 \
        --temp 1.0 --top-p 1.0 --top-k 0 \
        -p "The meaning to life and the universe is"

SGLang
  1. Build SGLang from source:

    # build from source
    git clone https://github.com/sgl-project/sglang
    cd sglang
    pip3 install pip --upgrade
    pip3 install -e "python[all]"
    
    # ROCm 6.3
    pip3 install torch==2.8.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/rocm6.3
    git clone https://github.com/triton-lang/triton
    cd python/triton_kernels
    pip3 install .
    
    # hopper
    pip3 install torch==2.8.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu126
    pip3 install sgl-kernel==0.3.2
    
    # blackwell cu128
    pip3 install torch==2.8.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu128
    pip3 install https://github.com/sgl-project/whl/releases/download/v0.3.2/sgl_kernel-0.3.2+cu128-cp39-abi3-manylinux2014_x86_64.whl
    
    # blackwell cu129
    pip3 install torch==2.8.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu129
    pip3 install https://github.com/sgl-project/whl/releases/download/v0.3.2/sgl_kernel-0.3.2-cp39-abi3-manylinux2014_x86_64.whl
  2. Launch SGLang server:

    python3 -m sglang.launch_server --model-path ./gpt-oss-finetuned-merged/ --port 8000
  3. Run inference:

    import requests
    from sglang.utils import print_highlight
    
    url = "http://localhost:8000/v1/chat/completions"
    
    data = {
        "model": "gpt-oss-finetuned-merged",
        "messages": [{"role": "user", "content": "What is the capital of France?"}],
    }
    
    response = requests.post(url, json=data)
    print_highlight(response.json())

♦️Fine-tuning gpt-oss directly

We also added support for directly fine-tuning gpt-oss models by implementing patches that allow loading the native MXFP4 quantized format. This makes it possible to load the 'openai/gpt-oss' model with less than 24GB of VRAM and QLoRA fine-tune it. Simply load the model using:

from unsloth import FastLanguageModel

dtype = None  # None for auto detection
max_seq_length = 1024  # illustrative value only; choose any for long context!

model, tokenizer = FastLanguageModel.from_pretrained(
    # model_name = "unsloth/gpt-oss-20b-BF16", 
    model_name = "unsloth/gpt-oss-20b",
    dtype = dtype, # None for auto detection
    max_seq_length = max_seq_length, # Choose any for long context!
    load_in_4bit = True,  # 4 bit quantization to reduce memory
    full_finetuning = False, # [NEW!] We have full finetuning now!
    # token = "hf_...", # use one if using gated models
)

Then add a PEFT layer using FastLanguageModel.get_peft_model and run SFT fine-tuning over the PEFT model.

🐛Bug Fixes for gpt-oss

We recently collaborated with Hugging Face to resolve inference issues by using OpenAI’s kernels and ensuring that swiglu_limit = 7.0 is correctly applied during MXFP4 inference.

Based on user feedback, we discovered that extended QLoRA training runs (beyond 60 steps) could cause the loss to diverge and eventually error out. This issue only occurred on devices that do not support BF16 and instead fall back to F16 (e.g., T4 GPUs). Importantly, it did not impact QLoRA training on A100 or H100 GPUs, nor LoRA training on f16 GPUs.

After extensive investigation, we’ve now aligned training loss behavior across all GPU setups, including GPUs limited to F16. If you were previously experiencing issues because of this, we recommend using our new updated gpt-oss notebook!

We had to run many experiments to make float16's training loss curve match that of bfloat16 machines (blue line). We found the following:

  1. With pure float16, the loss goes to infinity at step 50

  2. We found the down projections in the MoE to have huge outliers

  3. Activations must be saved in bfloat16 or float32

Below are the absolute activation magnitudes for gpt-oss-20b. Some spike so high that they overflow on float16 machines, since float16's maximum representable value is 65504.
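The overflow itself is easy to reproduce in isolation. The snippet below casts a made-up outlier value (70000.0 is an illustrative number, not a measured activation) to both 16-bit formats: float16 saturates to infinity, while bfloat16 keeps the value at reduced precision because it shares float32's exponent range.

```python
import torch

print(torch.finfo(torch.float16).max)   # largest finite float16 value
print(torch.finfo(torch.bfloat16).max)  # bfloat16 has float32's exponent range

activation = torch.tensor([1000.0, 70000.0])  # second entry is an illustrative outlier
as_fp16 = activation.to(torch.float16)
as_bf16 = activation.to(torch.bfloat16)

print(as_fp16)  # the outlier no longer fits
print(as_bf16)  # the outlier survives, at reduced precision
print(torch.isinf(as_fp16).any().item(), torch.isinf(as_bf16).any().item())
```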

We fixed this in Unsloth, so all float16 training works out of the box!

🔢 Implementations for Sink Attention

OpenAI's sink token implementation is provided here. We reproduce it below:

def sdpa(Q, K, V, S, sm_scale, sliding_window=0):
    # sliding_window == 0 means no sliding window
    n_tokens, n_heads, q_mult, d_head = Q.shape
    assert K.shape == (n_tokens, n_heads, d_head)
    assert V.shape == (n_tokens, n_heads, d_head)
    K = K[:, :, None, :].expand(-1, -1, q_mult, -1)
    V = V[:, :, None, :].expand(-1, -1, q_mult, -1)
    S = S.reshape(n_heads, q_mult, 1, 1).expand(-1, -1, n_tokens, -1)
    mask = torch.triu(Q.new_full((n_tokens, n_tokens), -float("inf")), diagonal=1)
    if sliding_window > 0:
        mask += torch.tril(
            mask.new_full((n_tokens, n_tokens), -float("inf")), diagonal=-sliding_window
        )
    QK = torch.einsum("qhmd,khmd->hmqk", Q, K) * sm_scale
    QK += mask[None, None, :, :]
    QK = torch.cat([QK, S], dim=-1)
    W = torch.softmax(QK, dim=-1)
    W = W[..., :-1]
    attn = torch.einsum("hmqk,khmd->qhmd", W, V)
    return attn.reshape(n_tokens, -1)

The Hugging Face transformers implementation is provided here. We also reproduce it below:

def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)
    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1)
    combined_logits = torch.cat([attn_weights, sinks], dim=-1)

    # This was not in the original implementation and slightly affect results; it prevents overflow in BF16/FP16
    # when training with bsz>1 we clamp max values.

    combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values
    probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype)
    scores = probs[..., :-1]  # we drop the sink here
    attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, attn_weights
