Long Context gpt-oss Training
We’re excited to introduce Unsloth Flex Attention support for OpenAI gpt-oss training, which enables >8× longer context lengths, >50% less VRAM usage and >1.5× faster training compared with all other implementations, including those using Flash Attention 3 (FA3). Unsloth Flex Attention makes it possible to train with a 60K context length on an 80GB VRAM H100 GPU for BF16 LoRA. Also:
You can now export/save your QLoRA fine-tuned gpt-oss model to llama.cpp, vLLM, Ollama or HF
We fixed gpt-oss training losses going to infinity on float16 GPUs (like T4 Colab)
We fixed gpt-oss implementation issues unrelated to Unsloth, most notably ensuring that swiglu_limit = 7.0 is properly applied during MXFP4 inference in transformers
🦥Introducing Unsloth Flex Attention Support
With Unsloth's Flex Attention support, a single 80GB VRAM H100 can handle up to 81K context length with QLoRA and 60K context with BF16 LoRA! These gains apply to BOTH gpt-oss-20b and gpt-oss-120b! The longer the context, the bigger the gains from Unsloth Flex Attention:

In comparison, all other non-Unsloth implementations max out at 9K context length on an 80GB GPU, and only reach 15K context with FA3. However, FA3 is unsuitable for gpt-oss training because it lacks backward pass support for attention sinks, so if you were previously using FA3 for gpt-oss training, we recommend not using it for now. Thus, the max context length you can get without Unsloth on 80GB VRAM is ~9K.
Training with Unsloth Flex Attention delivers at least a 1.3× speedup, with gains growing as context length increases and reaching up to 2×. Because Flex Attention's savings scale with context, longer sequences yield bigger savings in both VRAM and training time, as described here.
A huge thank you to Rohan Pandey for his Flex Attention implementation, which directly inspired the development of Unsloth's Flex Attention implementation.
🕶️ Attention Sinks
OpenAI's GPT OSS model uses an alternating pattern of sliding window attention and full attention layers (SWA, FA, SWA, FA, etc.). Each sliding window only attends to 128 tokens (including the current token), so computation is vastly reduced. However, such a small window on its own would also cripple long context retrieval and reasoning. Most labs address this by expanding the sliding window to 2048 or 4096 tokens.
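To make the compute savings concrete, here is a small pure-Python sketch that counts how many query-key pairs each layer type evaluates. It assumes a 128-token window that includes the current token and a layer pattern that starts with a sliding window layer, both as described above. The helper names and the 1,024-token sequence length are hypothetical choices for illustration.

```python
from typing import List, Optional

WINDOW = 128  # sliding window size, counting the current token


def layer_types(num_layers: int) -> List[str]:
    """Alternate sliding window and full attention: SWA, FA, SWA, FA, ..."""
    return ["SWA" if i % 2 == 0 else "FA" for i in range(num_layers)]


def attended_pairs(seq_len: int, window: Optional[int]) -> int:
    """Count the (query, key) pairs a causal attention layer computes.

    window=None means full causal attention. Otherwise query q sees at most
    `window` keys: itself plus the window - 1 tokens before it.
    """
    total = 0
    for q in range(seq_len):
        visible = q + 1  # causal: keys 0..q
        if window is not None:
            visible = min(visible, window)
        total += visible
    return total


n = 1024
print(layer_types(4))                              # ['SWA', 'FA', 'SWA', 'FA']
print(attended_pairs(n, None), attended_pairs(n, WINDOW))  # 524800 122944
```

The full layers grow quadratically with sequence length, while the SWA layers grow only linearly once the sequence is longer than the window.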
OpenAI instead leveraged Attention Sinks from the paper Efficient Streaming Language Models with Attention Sinks, which shows that you can use a small sliding window as long as you add global attention on the first token! The paper provides a good illustration below:

The paper finds that the attention mechanism tends to assign a lot of weight to the first few tokens (1 to 4). A plain sliding window evicts these "important" tokens once they fall outside the window, which causes poor long context retrieval.
If we plot log perplexity (higher is worse) while running long context inference beyond the pretrained model's context length, we see the perplexity shoot up (not good). However, the red line (which uses Attention Sinks) stays low, which is very good!

The paper also shows that the Attention Is Off By One method partially works, but a few extra sink tokens must still be added to get lower perplexities. It further shows that adding a single learnable sink token does remarkably well! And that's what OpenAI did for GPT-OSS!
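The relationship between these ideas fits in a few lines. The pure-Python sketch below uses hypothetical helper names and assumes the standard definitions. "Softmax off by one" adds 1 to the softmax denominator, which is the same as appending an extra logit fixed at 0 and then dropping its column. A learnable sink replaces that fixed 0 with a trained logit. In both cases the remaining weights can sum to less than 1, so a head is free to attend to "nothing".

```python
import math
from typing import List


def softmax(xs: List[float]) -> List[float]:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_off_by_one(xs: List[float]) -> List[float]:
    """exp(x_i) / (1 + sum_j exp(x_j)), written out directly."""
    denom = 1.0 + sum(math.exp(x) for x in xs)
    return [math.exp(x) / denom for x in xs]


def softmax_with_sink(xs: List[float], sink_logit: float) -> List[float]:
    """Append a sink logit, apply softmax, then drop the sink column."""
    return softmax(xs + [sink_logit])[:-1]


logits = [2.0, -1.0, 0.5]
a = softmax_off_by_one(logits)
b = softmax_with_sink(logits, 0.0)  # sink logit fixed at 0
print(all(abs(x - y) < 1e-12 for x, y in zip(a, b)))  # True
print(sum(softmax_with_sink(logits, 3.0)) < 1.0)       # True: the sink absorbs mass
```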

📐Unsloth's Flex Attention implementation
Flex Attention (https://pytorch.org/blog/flexattention/) is extremely powerful because it gives the practitioner two customization routes for the attention mechanism: a score modifier (f) and a masking function (M).
The score modifier (f) lets us edit the attention logits before the softmax operation, and the masking function (M) lets us skip operations we don't need (e.g., sliding window attention only sees the last 128 tokens).
The trick is that Flex Attention automatically generates fast Triton kernels for arbitrary score modifiers and masking functions!
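Conceptually, setting aside the generated kernels and block sparsity, Flex Attention computes ordinary attention in which every kept logit first passes through the score modifier and masked-out positions are skipped. Below is a naive pure-Python reference of those semantics for a single head. The argument orders mirror PyTorch's score_mod(score, b, h, q_idx, kv_idx) and mask_mod(b, h, q_idx, kv_idx) callbacks. However, naive_flex_attention itself is a hypothetical illustration, not the PyTorch API.

```python
import math
from typing import Callable, List

Matrix = List[List[float]]
ScoreMod = Callable[[float, int, int, int, int], float]
MaskMod = Callable[[int, int, int, int], bool]


def naive_flex_attention(q: Matrix, k: Matrix, v: Matrix,
                         score_mod: ScoreMod, mask_mod: MaskMod) -> Matrix:
    """Reference semantics for one batch element (b=0) and one head (h=0).

    Assumes every query row keeps at least one key after masking.
    """
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for q_idx, q_row in enumerate(q):
        logits = {}
        for kv_idx, k_row in enumerate(k):
            if not mask_mod(0, 0, q_idx, kv_idx):
                continue  # masked out (M): the score is never computed
            score = scale * sum(a * b for a, b in zip(q_row, k_row))
            logits[kv_idx] = score_mod(score, 0, 0, q_idx, kv_idx)  # f edits the logit
        m = max(logits.values())
        weights = {j: math.exp(s - m) for j, s in logits.items()}
        z = sum(weights.values())
        out.append([sum(w / z * v[j][c] for j, w in weights.items())
                    for c in range(len(v[0]))])
    return out


def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


def identity(score, b, h, q_idx, kv_idx):
    return score


q = k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
out = naive_flex_attention(q, k, v, identity, causal)
print(out[0])  # [1.0, 2.0]: the first query can only see the first key
```

Swapping in a different mask_mod or score_mod changes the attention pattern without touching the rest of the computation. That separation is what lets Flex Attention generate one kernel per (f, M) pair.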
This means we can use Flex Attention to implement attention sinks! Reference implementations of a single attention sink are provided in both OpenAI's original GPT-OSS repo and Hugging Face's transformers implementation:
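import torch
import torch.nn.functional as F
# attn_weights: the Q @ K.T logits; sinks: one learned logit per head, expanded to a column
combined_logits = torch.cat([attn_weights, sinks], dim=-1)  # append the sink as the last column
probs = F.softmax(combined_logits, dim=-1)
scores = probs[..., :-1]  # drop the sink column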
The above shows that we concatenate the sink at the very end of Q @ K.T, do the softmax, and remove the last column, which was the sink token.
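Softmax does not depend on column order, so placing the sink column first instead of last yields exactly the same attention weights for the real tokens. Only the indexing changes. The pure-Python sketch below checks this for a single query row, which is the property that makes moving the sink to index 0 safe. The helper names and the logit values are hypothetical.

```python
import math
from typing import List


def softmax(xs: List[float]) -> List[float]:
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    z = sum(exps)
    return [e / z for e in exps]


def sink_last(scores: List[float], sink: float) -> List[float]:
    """Default layout: [scores..., sink], softmax, drop the last column."""
    return softmax(scores + [sink])[:-1]


def sink_first(scores: List[float], sink: float) -> List[float]:
    """Index-0 layout: [sink, scores...], softmax, drop the first column."""
    return softmax([sink] + scores)[1:]


row = [0.3, -2.0, 1.7, 0.0]  # one query's Q @ K.T logits (made-up values)
sink = 0.8                   # the head's learned sink logit (made-up value)
a, b = sink_last(row, sink), sink_first(row, sink)
print(max(abs(x - y) for x, y in zip(a, b)) < 1e-12)  # True
print(sum(a) < 1.0)  # True: the sink column absorbs part of the probability mass
```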
Using some visualization utilities from Flex Attention's GitHub repo, we can visualize this. Assume a sequence length of 16 and a sliding window of 5. On the left is the sink as the last column (default implementation), and on the right is the sink moved to index 0 (our implementation).
Sink location at the end (default)

Move sink location to index 0

Interesting finding: the official Flex Attention sliding window examples treat the window size as the number of last tokens PLUS ONE, because they also include the current token. The Hugging Face and GPT OSS implementations attend to exactly the last N tokens, counting the current one. For example, the code below is from https://pytorch.org/blog/flexattention/ and https://github.com/meta-pytorch/attention-gym:
def sliding_window_causal(b, h, q_idx, kv_idx):
    causal_mask = q_idx >= kv_idx
    window_mask = q_idx - kv_idx <= SLIDING_WINDOW
    return causal_mask & window_mask
Default Flex Attention (3+1 tokens)

HuggingFace, GPT-OSS (3+0 tokens)

We also checked OpenAI's official GPT-OSS implementation to confirm whether it attends to the last N or N+1 tokens: https://github.com/openai/gpt-oss/blob/main/gpt_oss/torch/model.py
mask = torch.triu(Q.new_full((n_tokens, n_tokens), -float("inf")), diagonal=1)
if sliding_window > 0:
    mask += torch.tril(
        mask.new_full((n_tokens, n_tokens), -float("inf")), diagonal=-sliding_window
    )

And we see that only the last 3 tokens (not 3+1) are attended to! This means that instead of using <= SLIDING_WINDOW, you should use < SLIDING_WINDOW (i.e., strictly less than, without the equals).
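To double-check the off-by-one on a small case, the sketch below rebuilds the additive mask from OpenAI's code in pure Python. torch.triu(..., diagonal=1) keeps the entries with col - row >= 1, and torch.tril(..., diagonal=-W) keeps the entries with col - row <= -W. The sketch then compares which positions stay unmasked against both Flex Attention predicates. The helper names are hypothetical.

```python
import math

NEG_INF = -math.inf


def gpt_oss_mask(n_tokens, sliding_window):
    """Pure-Python version of the additive triu/tril mask in gpt_oss/torch/model.py."""
    mask = [[0.0] * n_tokens for _ in range(n_tokens)]
    for i in range(n_tokens):          # query position
        for j in range(n_tokens):      # key position
            if j - i >= 1:             # torch.triu(..., diagonal=1): future keys
                mask[i][j] += NEG_INF
            if sliding_window > 0 and j - i <= -sliding_window:  # torch.tril(..., diagonal=-W)
                mask[i][j] += NEG_INF
    return mask


def allowed_positions(mask):
    """(query, key) pairs whose additive mask entry is 0, i.e. not masked."""
    return {(i, j) for i, row in enumerate(mask) for j, val in enumerate(row) if val == 0.0}


def predicate_positions(n_tokens, window, strict):
    """Pairs kept by a Flex Attention style predicate using < (strict) or <=."""
    pairs = set()
    for q_idx in range(n_tokens):
        for kv_idx in range(n_tokens):
            diff = q_idx - kv_idx
            in_window = diff < window if strict else diff <= window
            if q_idx >= kv_idx and in_window:
                pairs.add((q_idx, kv_idx))
    return pairs


mask = gpt_oss_mask(8, 3)
print([j for j in range(8) if mask[7][j] == 0.0])  # [5, 6, 7]: only the last 3 tokens
print(allowed_positions(mask) == predicate_positions(8, 3, strict=True))   # True
print(allowed_positions(mask) == predicate_positions(8, 3, strict=False))  # False
```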
def sliding_window_causal(b, h, q_idx, kv_idx):
    causal_mask = q_idx >= kv_idx
    window_mask = q_idx - kv_idx <= SLIDING_WINDOW  # Default Flex Attention
    window_mask = q_idx - kv_idx < SLIDING_WINDOW   # GPT-OSS version
    return causal_mask & window_mask
Also, since we moved the sink token to index 0, every real key shifts one position to the right, so we add 1 to q_idx to index correctly:
def causal_mask_with_sink(batch, head, q_idx, kv_idx):
    """
       0 1 2 3      0 1 2 3
    0  X X       1    X
    1  X X X     2    X X
    2  X X X X   3    X X X
    """
    # We add (q_idx + 1) since first column is sink token
    causal_mask = (q_idx + 1) >= kv_idx
    sink_first_column = kv_idx == 0
    return causal_mask | sink_first_column
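causal_mask_with_sink covers the full attention layers. For the sliding window layers, the same +1 shift also has to go into the window check. The following is our own hypothetical sketch and not necessarily Unsloth's exact code. With the sink at kv_idx 0, real key k sits at kv_idx = k + 1, so the GPT-OSS rule q - k < W becomes (q_idx + 1) - kv_idx < W. Bitwise & and | are used, as in the functions above, so the same predicate could also run on tensor indices. The sketch renders the sequence length 16, window 5 layout in ASCII and checks it against the unshifted GPT-OSS rule.

```python
SLIDING_WINDOW = 5


def sliding_window_with_sink(b, h, q_idx, kv_idx):
    """Sink at kv_idx == 0; real key k sits at kv_idx == k + 1."""
    sink_first_column = kv_idx == 0
    causal_mask = (q_idx + 1) >= kv_idx
    window_mask = (q_idx + 1) - kv_idx < SLIDING_WINDOW  # GPT-OSS: strictly less than
    return sink_first_column | (causal_mask & window_mask)


def reference(q_idx, k):
    """GPT-OSS rule on unshifted key indices (no sink column)."""
    return 0 <= q_idx - k < SLIDING_WINDOW


def render(n_tokens):
    """One text row per query; column 0 is the sink, columns 1..n are real keys."""
    rows = []
    for q in range(n_tokens):
        cells = ["X" if sliding_window_with_sink(0, 0, q, kv) else "."
                 for kv in range(n_tokens + 1)]
        rows.append("".join(cells))
    return "\n".join(rows)


print(render(16))
```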
To confirm our index 0 implementation, we verified that the training loss remains consistent with standard Hugging Face runs (without Unsloth Flex Attention), as shown in our graph:

💾NEW: Saving to GGUF, vLLM after gpt-oss training
You can now QLoRA fine-tune gpt-oss and directly save, export, or merge the model to llama.cpp, vLLM, or HF, not just Unsloth. We hope to release a free notebook soon.
Previously, any QLoRA fine-tuned gpt-oss model was restricted to running in Unsloth. We’ve removed that limitation by introducing on-demand dequantization of MXFP4 base models (like gpt-oss) during the LoRA merge process. This makes it possible to export your fine-tuned model in bf16 format.
After fine-tuning your gpt-oss model, you can now merge it into a 16-bit format with a single command:
model.save_pretrained_merged(save_directory, tokenizer)
If you'd rather merge the model and push it directly to the Hugging Face Hub, you can use:
model.push_to_hub_merged(repo_name, tokenizer=tokenizer, token=hf_token)
To run inference on the merged model, you can choose between vLLM, llama.cpp, and other engines.
⚙️ Inference Settings
OpenAI recommends these inference settings for both models: temperature=1.0, top_p=1.0, top_k=0
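With these values, temperature=1.0 leaves the logits unscaled and top_p=1.0 keeps the whole nucleus. top_k=0 is commonly treated as "no top-k filter", as in Hugging Face transformers, but check your engine's convention. The pure-Python sketch below uses a hypothetical sampling_distribution helper to show that the recommended settings reduce to sampling from the model's unmodified softmax.

```python
import math
from typing import List


def softmax(xs: List[float]) -> List[float]:
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    z = sum(exps)
    return [e / z for e in exps]


def sampling_distribution(logits: List[float], temperature: float = 1.0,
                          top_p: float = 1.0, top_k: int = 0) -> List[float]:
    """Final token probabilities after temperature, top-k and top-p filtering."""
    probs = softmax([x / temperature for x in logits])
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)

    # top_k <= 0 is treated as "no top-k filtering"
    keep = order if top_k <= 0 else order[:top_k]

    # top_p >= 1.0 keeps everything; otherwise keep the smallest prefix reaching top_p
    if top_p < 1.0:
        nucleus, cumulative = [], 0.0
        for i in keep:
            nucleus.append(i)
            cumulative += probs[i]
            if cumulative >= top_p:
                break
        keep = nucleus

    kept = set(keep)
    kept_mass = sum(probs[i] for i in kept)
    return [probs[i] / kept_mass if i in kept else 0.0 for i in range(len(probs))]


logits = [2.0, 1.0, 0.5, -1.0]
recommended = sampling_distribution(logits, temperature=1.0, top_p=1.0, top_k=0)
print(all(abs(a - b) < 1e-12 for a, b in zip(recommended, softmax(logits))))  # True
```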
♦️Fine-tuning gpt-oss directly
We also added support for fine-tuning gpt-oss models directly by implementing patches that allow loading the native MXFP4 quantized format. This makes it possible to load the 'openai/gpt-oss' model with less than 24GB of VRAM and QLoRA fine-tune it. Simply load the model using:
from unsloth import FastLanguageModel

max_seq_length = 1024  # example value; choose any for long context
dtype = None  # None for auto detection

model, tokenizer = FastLanguageModel.from_pretrained(
    # model_name = "unsloth/gpt-oss-20b-BF16",
    model_name = "unsloth/gpt-oss-20b",
    dtype = dtype, # None for auto detection
    max_seq_length = max_seq_length, # Choose any for long context!
    load_in_4bit = True,  # 4 bit quantization to reduce memory
    full_finetuning = False, # [NEW!] We have full finetuning now!
    # token = "hf_...", # use one if using gated models
)
Then add LoRA adapters with FastLanguageModel.get_peft_model and run SFT fine-tuning over the resulting PEFT model.
🐛Bug Fixes for gpt-oss
We recently collaborated with Hugging Face to resolve inference issues by using OpenAI’s kernels and ensuring that swiglu_limit = 7.0 is correctly applied during MXFP4 inference.
Based on user feedback, we discovered that extended QLoRA training runs (beyond 60 steps) could cause the loss to diverge and eventually error out. This issue only occurred on devices that do not support BF16 and instead fall back to F16 (e.g., T4 GPUs). Importantly, it did not affect QLoRA training on A100 or H100 GPUs, nor LoRA training on F16 GPUs.
After extensive investigation, we’ve now aligned training loss behavior across all GPU setups, including GPUs limited to F16. If you were previously experiencing issues because of this, we recommend using our updated gpt-oss notebook!

We ran many, many experiments to bring float16's training loss curve in line with that of bfloat16 machines (blue line). We found the following:
Pure float16 training goes to infinity at step 50
The down projections in the MoE have huge outliers
Activations must be saved in bfloat16 or float32
The chart below shows the absolute activation magnitudes for GPT OSS 20B. Some of them spike sharply, which overflows on float16 machines because float16's maximum representable value is 65504.
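The overflow is easy to reproduce with Python's standard library, whose struct module can pack IEEE 754 half precision (the 'e' format). In the sketch below, the helper names are hypothetical and the bfloat16 conversion simply truncates a float32 to its top 16 bits, a rough stand-in for real round-to-nearest. It shows that a value just above 65504 cannot be stored in float16. bfloat16, which keeps float32's 8-bit exponent, still stores it, only with reduced precision.

```python
import struct


def to_float16(x: float) -> float:
    """Round-trip through IEEE half precision; raises OverflowError if too large."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def fits_in_float16(x: float) -> bool:
    try:
        to_float16(x)
        return True
    except OverflowError:
        return False


def to_bfloat16_truncated(x: float) -> float:
    """Keep the top 16 bits of the float32 encoding (sign, 8-bit exponent, 7-bit mantissa)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


print(to_float16(65504.0))               # 65504.0, the largest finite float16
print(fits_in_float16(70000.0))          # False: an outlier like this overflows
print(to_bfloat16_truncated(70000.0))    # 69632.0: representable, just coarser
```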
We fixed this in Unsloth, so all float16 training works out of the box!

🔢 Implementations for Sink Attention
OpenAI's sink token implementation is provided here; we reproduce it below:
import torch


def sdpa(Q, K, V, S, sm_scale, sliding_window=0):
    # sliding_window == 0 means no sliding window
    n_tokens, n_heads, q_mult, d_head = Q.shape
    assert K.shape == (n_tokens, n_heads, d_head)
    assert V.shape == (n_tokens, n_heads, d_head)
    K = K[:, :, None, :].expand(-1, -1, q_mult, -1)
    V = V[:, :, None, :].expand(-1, -1, q_mult, -1)
    S = S.reshape(n_heads, q_mult, 1, 1).expand(-1, -1, n_tokens, -1)
    mask = torch.triu(Q.new_full((n_tokens, n_tokens), -float("inf")), diagonal=1)
    if sliding_window > 0:
        mask += torch.tril(
            mask.new_full((n_tokens, n_tokens), -float("inf")), diagonal=-sliding_window
        )
    QK = torch.einsum("qhmd,khmd->hmqk", Q, K) * sm_scale
    QK += mask[None, None, :, :]
    QK = torch.cat([QK, S], dim=-1)  # append the sink logit as the last column
    W = torch.softmax(QK, dim=-1)
    W = W[..., :-1]  # drop the sink column
    attn = torch.einsum("hmqk,khmd->qhmd", W, V)
    return attn.reshape(n_tokens, -1)
The Hugging Face transformers implementation is provided here; we reproduce it below:
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

# repeat_kv (expands grouped key/value heads) is defined in the same transformers modeling file.


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)
    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1)
    combined_logits = torch.cat([attn_weights, sinks], dim=-1)

    # This was not in the original implementation and slightly affects results; it prevents overflow in BF16/FP16
    # when training with bsz>1 we clamp max values.
    combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values
    probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype)
    scores = probs[..., :-1]  # we drop the sink here
    attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, attn_weights
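The max subtraction in eager_attention_forward works because softmax is shift-invariant. Subtracting the row maximum from every logit, sink included, leaves the probabilities mathematically unchanged while keeping every exponent at or below 0. Low-precision rounding can still shift results slightly, as the code comment notes. Here is a minimal pure-Python illustration using Python floats rather than BF16/FP16 tensors. The helper names are hypothetical.

```python
import math
from typing import List


def softmax_naive(xs: List[float]) -> List[float]:
    exps = [math.exp(x) for x in xs]  # raises OverflowError for large logits
    z = sum(exps)
    return [e / z for e in exps]


def softmax_shifted(xs: List[float]) -> List[float]:
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]  # every exponent is <= 0, so no overflow
    z = sum(exps)
    return [e / z for e in exps]


# The last entry plays the role of the sink logit, as in combined_logits.
small = [1.0, 2.0, 3.0, 0.5]
large = [1000.0, 1001.0, 999.0, 1000.5]

same = all(abs(a - b) < 1e-12 for a, b in zip(softmax_naive(small), softmax_shifted(small)))

try:
    softmax_naive(large)
    naive_overflowed = False
except OverflowError:
    naive_overflowed = True

scores = softmax_shifted(large)[:-1]  # drop the sink column, as in eager_attention_forward
print(same, naive_overflowed)  # True True
```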