OpenAI gpt-oss & all model types now supported!

gpt-oss Reinforcement Learning

You can now train OpenAI gpt-oss with RL and GRPO via Unsloth. Unsloth now offers the fastest inference (3x faster), lowest VRAM usage (50% less) and longest context (8x longer) for gpt-oss RL vs. any other implementation, with no accuracy degradation. Since reinforcement learning (RL) on gpt-oss isn't yet vLLM compatible, we rewrote Transformers' inference code to deliver 3x faster inference for gpt-oss at ~21 tokens/s. For BF16, Unsloth also achieves the fastest inference (~30 tokens/s), especially relative to VRAM usage, using 50% less VRAM than any other RL implementation. We plan to support our 50% weight sharing feature once vLLM becomes compatible with RL.

With Unsloth, you can train gpt-oss-20b with GRPO on 15GB of VRAM and for free on Colab. We also introduced embedding offloading via offload_embeddings, which reduces VRAM usage by a further 1GB. Unsloth's new inference runs faster on any GPU, including A100s, H100s and older T4s. gpt-oss-120b fits nicely on an 80GB VRAM GPU.
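As a rough sketch of the setup, a 4-bit GRPO run might be configured as below. The exact placement of the offload_embeddings option is an assumption here; the free Colab notebook has the canonical configuration.

```python
from unsloth import FastLanguageModel

# Load gpt-oss-20b in 4-bit so the 20b model fits inside ~15GB of VRAM.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/gpt-oss-20b",
    max_seq_length = 2048,       # raise for longer-context RL if VRAM allows
    load_in_4bit = True,
    offload_embeddings = True,   # assumed kwarg placement: offloads embeddings to save ~1GB
)

# Attach LoRA adapters so only a small set of weights is trained during GRPO.
model = FastLanguageModel.get_peft_model(
    model,
    r = 8,
    lora_alpha = 16,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj"],
)
```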

Unsloth is the only framework to support 4-bit RL for gpt-oss. All performance gains are due to Unsloth's unique weight sharing, Flex Attention, Standby and custom kernels.

⚡Making Inference Much Faster

Inference is crucial in RL training, since we need it to generate candidate solutions before maximizing some reward function (see here for a more detailed explanation). To achieve the fastest inference speed for gpt-oss without vLLM, we rewrote Transformers' inference code and integrated many innovations, including custom algorithms like Unsloth Flex Attention and special flags within torch.compile (like combo kernels). Our new inference code for gpt-oss was evaluated against an already optimized baseline (2x faster than native Transformers).
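As a rough illustration of the pieces involved, the sketch below compiles flex_attention with torch.compile and turns on Inductor's combo-kernel fusion. The exact flags Unsloth ships are internal, and the combo_kernels option's availability depends on your PyTorch version.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

# Illustrative: combo_kernels asks Inductor to horizontally fuse many small
# kernels; name and availability can vary between PyTorch versions.
torch._inductor.config.combo_kernels = True

# Compiling flex_attention is essential: the eager fallback materializes the
# full scores matrix and is far slower.
flex_attention = torch.compile(flex_attention, dynamic=False)

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

# q, k, v: (batch, heads, seq_len, head_dim)
q = k = v = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.bfloat16)
block_mask = create_block_mask(causal, B=None, H=None,
                               Q_LEN=1024, KV_LEN=1024, _compile=True)
out = flex_attention(q, k, v, block_mask=block_mask)
```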

vLLM does not support RL for gpt-oss since it lacks BF16 training and LoRA support for the model. Without Unsloth, only full-precision BF16 training works, making memory use 800%+ higher. Most frameworks enable FA3 (Flash Attention 3) by default, which reduces VRAM use and increases speed, but FA3 produces incorrect training loss for gpt-oss because its backward pass does not support attention sinks (see Issue 1797 in the FA3 repo). You must therefore disable FA3, yet doing so prevents long-context training, since FA3 uses O(N) memory whilst naive attention balloons to O(N^2). To keep attention sinks differentiable without giving up memory efficiency, we implemented Unsloth Flex Attention.

We evaluated gpt-oss RL inference by benchmarking BitsandBytes 4-bit and also did separate tests for BF16. Unsloth’s 4-bit inference is ~4x faster, and BF16 is also more efficient, especially in VRAM use.

The best part about Unsloth's gpt-oss RL is that it can work on any GPU, even those that do not support BF16. Our free gpt-oss-20b Colab notebooks use older 15GB T4 GPUs, so the inference examples work well!

🛠️ gpt-oss Flex Attention Issues and Quirks

We had to change our implementation of attention sinks as described here to allow generation to work with left padding. We obtain the logsumexp and apply a sigmoid activation to rescale the attention output, as shown below:

$$A(X) = \mathrm{softmax}\!\left(\frac{1}{\sqrt{d}}QK^T\right)V = \frac{\exp\left(\frac{1}{\sqrt{d}}QK^T\right)}{\sum \exp\left(\frac{1}{\sqrt{d}}QK^T\right)}\,V$$

$$\mathrm{LSE} = \log \sum \exp\left(\frac{1}{\sqrt{d}}QK^T\right)$$

$$A_{\text{sinks}}(X) = A(X) \odot \sigma\left(\mathrm{LSE} - \text{sinks}\right)$$
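Ignoring masking and the fused kernel, the rescaling corresponds to the plain PyTorch reference computation below; tensor shapes are illustrative.

```python
import torch
import torch.nn.functional as F

def attention_with_sinks(q, k, v, sinks):
    # q, k, v: (batch, heads, seq, head_dim); sinks: (heads,) learned scalars.
    d = q.shape[-1]
    scores = (q @ k.transpose(-2, -1)) / d**0.5        # (B, H, S_q, S_k)
    attn   = F.softmax(scores, dim=-1)
    out    = attn @ v                                   # A(X)
    lse    = torch.logsumexp(scores, dim=-1)            # (B, H, S_q)
    # Sink correction: rescale each row by sigma(LSE - sink), so the sink acts
    # like an extra, value-less attention slot that stays differentiable.
    return out * torch.sigmoid(lse - sinks.view(1, -1, 1)).unsqueeze(-1)
```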

Left-padded masking during inference was also a tricky issue to deal with in gpt-oss. We found that we not only had to account for the KV cache prefill during token generation, but also for a different number of pad tokens in each prompt during batch generation, which changes how we need to store the block mask. An example can be seen below:

Normal Causal Mask:

   k0 k1 k2 k3 k4   <-- keys
q0  X
q1  X  X
q2  X  X  X
q3  X  X  X  X
q4  X  X  X  X  X   <-- last query row (most important for decoding)

For inference in the general case (decoding):

    k0 k1 k2 k3 k4
q0
q1
q2
q3
q4   X  X  X  X  X

If we naively use the same masking strategy, this'll fail:

    k0 k1 k2 k3 k4
q0
q1
q2
q3
q4   X   (note that q4 has q_idx=0 as this is the first query in current setup)

For generation (the decoding phase), we usually only care about the last row of the attention matrix, since there’s just one query token attending to all previous key tokens. If we naively apply the causal mask (q_idx ≥ k_idx), this fails, as our single query has index 0 while there are n_k key tokens. To fix this, we need an offset in mask creation to decide which tokens to attend to. But a naive approach is slow, since the offset changes each step, forcing mask and kernel regeneration. We solved this with cache and compile optimizations.
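A minimal sketch of the offset idea is below. This is the naive version described above, with illustrative sizes; the actual fix caches and compiles around it so the mask is not regenerated every step.

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

prompt_len, cache_len = 128, 128   # illustrative sizes

def causal_with_offset(offset):
    # `offset` = number of tokens already sitting in the KV cache; the current
    # query block starts at absolute position `offset`, not 0.
    def mask_mod(b, h, q_idx, kv_idx):
        return (q_idx + offset) >= kv_idx
    return mask_mod

# Prefill: offset = 0 gives the ordinary causal mask over the prompt.
prefill_mask = create_block_mask(causal_with_offset(0), B=None, H=None,
                                 Q_LEN=prompt_len, KV_LEN=prompt_len, _compile=True)

# Decode: one query at absolute position `cache_len` attends to every cached key.
# Rebuilding this mask each generation step is the slow, naive approach.
decode_mask = create_block_mask(causal_with_offset(cache_len), B=None, H=None,
                                Q_LEN=1, KV_LEN=cache_len + 1, _compile=True)
```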

The harder part is batch generation. Sequences differ in length, so padding complicates mask creation. Flex Attention posed a lot of challenges here, and dynamic masks are tricky to get right. Worse, if Flex Attention is not compiled, it falls back to eager attention, which is slow and memory-heavy (quadratic rather than linear in sequence length).

Quote from https://github.com/meta-pytorch/attention-gym/issues/15#issuecomment-2284148665

You need to call this with _compile=True. We essentially map your block mask over a full Q_LEN x KV_LEN matrix in order to produce the block mask. Without compile, we need to materialize this full thing, and it can cause OOMs on long sequences.

As well, you need to run flex_attention = torch.compile(flex_attention). Without compile, flex falls back to a non-fused eager implementation that is great for debugging, but it is much slower and materializes the full scores matrix.

Ultimately, the mask must dynamically handle prefill vs decode with the KV Cache, batch and padding tokens per sequence, remain torch.compile friendly, and support sliding windows.
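One way such a mask_mod might look is sketched below, assuming illustrative pad_lengths and cache_len values; the real implementation also folds in sliding windows and mask caching.

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

# Illustrative batch: pad_lengths[b] = number of left-pad tokens in sequence b.
pad_lengths = torch.tensor([0, 3, 7], device="cuda")
cache_len   = 32   # tokens already held in the KV cache

def decode_mask_mod(b, h, q_idx, kv_idx):
    not_padding = kv_idx >= pad_lengths[b]          # never attend to left padding
    causal      = (q_idx + cache_len) >= kv_idx     # offset the single decode query
    return not_padding & causal

block_mask = create_block_mask(decode_mask_mod, B=pad_lengths.shape[0], H=None,
                               Q_LEN=1, KV_LEN=cache_len + 1, _compile=True)
```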

🔍 Flash Attention Investigation

Another interesting direction we explored was integrating Flash Attention. Its advantages are widely recognized, so it seemed like an obvious choice to try, but one limitation is that it does not support attention sinks during the backward pass for gpt-oss. To work around this, we restructured the attention mechanism so that it operates solely on the attention output and the logsumexp values that Flash Attention readily provides.

However, we soon began noticing issues. While the first few layers behaved as expected, the later layers, particularly layers 18 through 24, produced outputs that diverged significantly from the eager-mode implementation in transformers. Importantly, this discrepancy cannot be attributed to error accumulation, since the inputs to each method are identical at every layer. For further validation, we also compared the results against Unsloth FlexAttention.
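As a sketch of this methodology (the layer containers and captured inputs here are hypothetical placeholders), each backend is fed the same captured per-layer input, so any divergence is attributable to the attention implementation itself rather than to accumulated error:

```python
import torch

@torch.no_grad()
def compare_attention_backends(eager_layers, flash_layers, per_layer_inputs):
    # per_layer_inputs[i] is the hidden state captured at layer i from a single
    # reference run, so both backends see identical inputs at every layer.
    diffs = []
    for i, x in enumerate(per_layer_inputs):
        out_eager = eager_layers[i](x)
        out_flash = flash_layers[i](x)
        diffs.append((i, (out_eager - out_flash).abs().max().item()))
    return diffs
```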

Further investigation is needed into why only the last few layers show such a drastic difference between the Flash Attention implementation and the others.

Note: Flash Attention 3 doesn't support the backward pass for attention sinks.

⚠️ Can We Counter Reward Hacking?

The ultimate goal of RL is to maximize some reward (say speed, revenue, or some other metric). But RL can cheat. When the RL algorithm learns a trick or exploits something to increase the reward without actually doing the task at hand, this is called "Reward Hacking".

It's the reason models learn to modify unit tests just to pass coding challenges, and such exploits are critical blockers for real-world deployment. Wikipedia has some other good examples.

In our free gpt-oss RL notebook, we explore how to counter reward hacking in a code-generation setting and showcase tangible solutions to common failure modes. We saw the model edit the timing function, outsource work to other libraries, cache the results, and outright cheat. After countering these, the model generates genuinely optimized matrix multiplication kernels, not clever cheats.

🏆Reward Hacking

Some common examples of reward hacking during RL include:

Laziness

RL learns to use NumPy, Torch and other libraries, which call optimized CUDA kernels. We can stop the RL algorithm from outsourcing to optimized code by inspecting whether the generated code imports non-standard Python libraries.
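A simple way to implement such a check, with an illustrative allow-list rather than the exact one used in the notebook:

```python
import ast

ALLOWED_MODULES = {"math"}   # illustrative allow-list; numpy, torch, etc. are banned

def uses_only_allowed_imports(code: str) -> bool:
    # Parse the generated code and reject it if it imports anything outside
    # the allow-list (the "laziness" hack of outsourcing to optimized libraries).
    try:
        tree = ast.parse(code)
    except SyntaxError:
        return False
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            if any(alias.name.split(".")[0] not in ALLOWED_MODULES for alias in node.names):
                return False
        elif isinstance(node, ast.ImportFrom):
            if (node.module or "").split(".")[0] not in ALLOWED_MODULES:
                return False
    return True
```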

Caching & Cheating

RL learns to cache the result of the output, and it also learns to find the actual output by inspecting Python global variables.

We can stop the RL algorithm from using cached data by wiping the cache with a large fake matrix. We also have to benchmark carefully, with multiple loops and timing runs.
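A sketch of what such a benchmark harness might look like; the matrix size, loop count and trial count are illustrative.

```python
import gc
import time
import numpy as np

def benchmark(fn, A, B, loops=3, trials=5):
    # Run several trials; before each trial, thrash memory/cache with a large
    # throwaway matrix product so a cached previous result can't be reused "for free".
    timings = []
    out = None
    for _ in range(trials):
        _ = np.random.rand(2048, 2048) @ np.random.rand(2048, 2048)  # cache wipe
        gc.collect()
        start = time.perf_counter()
        for _ in range(loops):
            out = fn(A, B)
        timings.append((time.perf_counter() - start) / loops)
    return out, min(timings)   # best average time per call across trials
```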

Cheating

RL learns to edit the timing function to make it report 0 time elapsed. We can stop the RL algorithm from using global or cached variables by restricting its locals and globals. Since we use exec to create the function, we have to save the output into an empty dict. We also disallow global variable access via types.FunctionType(f.__code__, {}).
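Concretely, the sandboxing described above might look like the sketch below, where func_name is a hypothetical name for the generated kernel.

```python
import types

def build_sandboxed_function(code_str: str, func_name: str = "matmul"):
    # Execute the generated source into a throwaway namespace so it cannot
    # touch our globals or monkey-patch the timing function.
    namespace = {}
    exec(code_str, namespace)
    f = namespace[func_name]
    # Rebind the function with an empty globals dict to disallow global
    # variable access (e.g. reading cached results).
    return types.FunctionType(f.__code__, {})
```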

💫From OpenAI's Labs to Your Laptop

gpt-oss is a frontier-class architecture from OpenAI that could power breakthrough AI applications. Until now, training these models with RL was exclusively limited to well-funded labs with H100s to spare.

Today, that changes. You can train gpt-oss-20b with GRPO on a free Google Colab tier here. Free, Frontier Model, Training.
