
Memory Efficient RL

We're excited to introduce more efficient reinforcement learning (RL) in Unsloth with multiple algorithmic advancements:

  • 1.2 to 1.7x increased context lengths with no slowdown and no extra memory usage!

  • 10% faster RL training runs with revamped kernels and async data movements

  • 2x faster torch.compile times during model loading

Unsloth already increases RL training speed and context window while reducing VRAM usage by 50–90% vs. other setups using FA2, and Unsloth's new Standby feature improves this even further. Qwen3-32B LoRA 16-bit can now reach a context length of 6,144 vs. 3,600 before (1.7x longer) on a single H100 80GB GPU, and Llama-3.1-8B QLoRA 4-bit can reach 47,500 vs. 42,000 before (1.13x longer).

We made RL runs 10% faster through various kernel optimizations and by removing the LoRA communication channel between the CPU and GPU when switching from training to inference mode. Finally, torch.compile makes vLLM's inference rollout about 10% faster, though compilation itself can be slow, which is why we also cut compile times (see below).

How to enable optimizations

To enable Unsloth's Standby feature, set the environment variable UNSLOTH_VLLM_STANDBY to "1" before any Unsloth import, then set gpu_memory_utilization = 0.95:

import os
os.environ["UNSLOTH_VLLM_STANDBY"] = "1"

from unsloth import FastLanguageModel
import torch
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Qwen3-8B-Base",
    max_seq_length = 2048, # Can increase for longer reasoning traces
    load_in_4bit = False, # False for LoRA 16bit
    fast_inference = True,
    max_lora_rank = 32, # Larger rank = smarter, but slower
    gpu_memory_utilization = 0.95,
)
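
From here, a GRPO run proceeds as in our notebooks. Below is a minimal, illustrative sketch assuming TRL's GRPOConfig / GRPOTrainer API, a LoRA adapter added via FastLanguageModel.get_peft_model, and a hypothetical length-based reward function plus a dataset prepared earlier (not shown):

from trl import GRPOConfig, GRPOTrainer

# Add LoRA adapters so only a small set of weights is trainable (rank matches max_lora_rank above)
model = FastLanguageModel.get_peft_model(
    model,
    r = 32,
    lora_alpha = 32,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj"],
)

def length_reward(completions, **kwargs):
    # Hypothetical reward: prefer shorter completions (replace with your own reward functions)
    return [-float(len(c)) for c in completions]

training_args = GRPOConfig(
    output_dir = "outputs",
    num_generations = 2,               # GRPO needs at least 2 generations per prompt
    max_prompt_length = 512,
    max_completion_length = 1024,
    gradient_accumulation_steps = 2,
    max_steps = 50,
)

trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [length_reward],
    args = training_args,
    train_dataset = dataset,           # assumed: a dataset with a "prompt" column prepared earlier
)
trainer.train()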

🎓No more gpu_memory_utilization!

With Unsloth's new RL improvements, you never have to worry about tuning gpu_memory_utilization again. Simply set it to 0.90 or 0.95; 100% sadly won't work since some space is still needed for small tensors. Previously you had to tune it anywhere between 30% and 95%. Now, just set it to the maximum and Unsloth handles the rest!

⁉️Why does RL use so much memory?

GRPO (like many RL variants) relies heavily on generation, which is primarily powered by vLLM. But this comes with a steep cost: it requires GPU memory for weights, activations, and the KV cache at all times.

Inference takes a lot of VRAM

Training also uses VRAM!

This means RL needs to keep two sets of memory on the GPU at the same time:

  1. Inference engine (has model weights, KV cache)

  2. Training engine (has model weights, activations, gradients, optimizer states)

Current RL frameworks have to split an 80GB GPU roughly 50/50, with 50% for inference and 50% for training:

80GB GPU                                 | Inference Engine (50%) | Training Engine (50%)
Model Weights                            | 16GB                   | 16GB
KV Cache                                 | 24GB                   | -
Activations, Gradients, Optimizer States | -                      | 24GB
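
As a rough back-of-the-envelope check on these numbers, here is a small sketch assuming Llama-3.1-8B-like shapes (32 layers, 8 KV heads, head dimension 128, bf16 weights and cache); exact figures depend on the model and quantization:

# Rough memory arithmetic behind the table above (assumed Llama-3.1-8B-like shapes)
params          = 8e9                         # ~8B parameters
bytes_per_param = 2                           # bf16 / fp16
weights_gb      = params * bytes_per_param / 1e9
print(f"weights  ~{weights_gb:.0f} GB each")  # ~16 GB, stored twice without weight sharing

# KV cache cost per token: K and V, per layer, per KV head, per head dim, 2 bytes each
layers, kv_heads, head_dim = 32, 8, 128
kv_bytes_per_token = 2 * layers * kv_heads * head_dim * 2
print(f"KV cache ~{kv_bytes_per_token / 1024:.0f} KiB per token")   # ~128 KiB

# With a 24 GB KV cache budget, roughly how many tokens fit across all sequences?
kv_budget_bytes = 24e9
print(f"~{kv_budget_bytes / kv_bytes_per_token / 1e3:.0f}K tokens fit in 24 GB")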

Previous Unsloth versions already optimize the above: we share vLLM's weight space directly, which removes the double memory usage of the model weights. This frees up 16GB in this example, which can be used to increase context length or generation speed.

80GB GPU                                 | Inference Engine (50%) | Training Engine (50%)
Model Weights                            | 16GB SHARED            | <<< SHARED
KV Cache                                 | 24GB + 8GB = 32GB      | -
Activations, Gradients, Optimizer States | -                      | 24GB + 8GB = 32GB

🦥Unsloth Standby

But we can go further. Note that RL alternates between inference and training: inference, then training, then inference, and so on.

This means the memory spaces for inference and training can in theory be time-shared rather than held simultaneously. This is where vLLM's sleep mode feature comes in:

  1. level = 1 copies weights to the CPU, deletes KV cache

  2. level = 2 deletes weights, deletes KV cache
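
Outside of Unsloth, a minimal sketch of this sleep/wake cycle with vLLM directly might look like the following (assuming a recent vLLM build that supports enable_sleep_mode; model name taken from the example above):

from vllm import LLM, SamplingParams

llm = LLM(model = "unsloth/Qwen3-8B-Base", enable_sleep_mode = True)
outputs = llm.generate(["The capital of France is"], SamplingParams(max_tokens = 16))

llm.sleep(level = 1)   # level 1: offloads weights to CPU RAM and frees the KV cache
# ... run a training step here while the GPU memory is free ...
llm.wake_up()          # restores the weights and re-allocates the KV cache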

But recall that in Unsloth we share vLLM's memory space for the weights, so neither sleep level fits directly: we need a way to delete only the KV cache while leaving the shared weights untouched. We call this Unsloth Standby.

80GB GPU                 | Inference Engine | Training Engine
Model Weights            | 16GB SHARED      | <<< SHARED
Multi-purpose 64GB space | KV Cache         | Activations, Gradients, Optimizer States

To enable it, simply add the line below at the start of all RL / GRPO training runs with Unsloth:

import os
os.environ["UNSLOTH_VLLM_STANDBY"] = "1"

🧪Performance Experiments

Here we describe how we benchmarked memory usage and context length for GRPO. Note that we use 2 generations per prompt because GRPO needs at least 2 generations to compute a sample mean and variance. With only 1 generation, the standard deviation of a single sample is zero, which makes the advantage, which contains the term (reward - mean) / std, undefined:

$$Z=\frac{r_i - \mu}{\sqrt{\frac{1}{n}\sum(r_i-\mu)^2}}, \qquad Z_{n=1}=\frac{r_1 - \mu}{\sqrt{\frac{1}{1}\sum(r_1-\mu)^2}}=\frac{0}{0}=\text{undefined}$$

This means that for GRPO specifically, a maximum context length of 6,144 for Qwen3-32B actually corresponds to 6,144 multiplied by 2 generations, i.e. 12,288 tokens in total.
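
To make the single-generation degenerate case above concrete, here is a tiny numerical sketch (illustrative reward values only):

import numpy as np

def grpo_advantages(rewards):
    # Normalize rewards within one prompt's group of generations
    rewards = np.asarray(rewards, dtype = np.float64)
    return (rewards - rewards.mean()) / rewards.std()   # divides by zero when std == 0

print(grpo_advantages([1.0, 0.0]))   # 2 generations: [ 1. -1.]
print(grpo_advantages([1.0]))        # 1 generation: std = 0, result is nan (undefined)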

We provide experiments for Llama-3.1 8B on both LoRA (16bit) and QLoRA (4bit) below:

Any training time differences are negligible: in our apples-to-apples comparisons we saw under 1% slowdowns, or even speedups, which fall within the margin of error.

We also theorize that speedups are possible due to reduced memory pressure, meaning less memory cleanup is needed.

In the above image, you can see the difference between baseline and Standby mode on a single T4 GPU for Qwen3-4B. We can push vLLM's gpu_memory_utilization as high as 0.95 without worrying that it will affect training. This means you can fit longer sequences and process more sequences at once. In the first case, for example, we have enough memory to fit and process 32K-length sequences (provided training allows it), whereas previously any input longer than 2K could fail to fit and cause an OOM on Colab.

Experiments

Config                                                      | Status                                                     | GPU memory usage                     | Comments
standby=True, vllm_gpu_util=0.95, num_gen=2, grad_acc_steps=2 | Runs 40 steps in 40 minutes                              | 14.5 GiB (set by vllm_gpu_util)      | Enough to fit a 32K KV cache with 2-4K chunks, or say a 16K KV cache + 16K chunks
standby=True, vllm_gpu_util=0.9, num_gen=2, grad_acc_steps=2  | Runs 32 steps in 40 minutes                              | 13.8 GiB (set by…)                   | Approximately enough to fit a ~28K KV cache with 2-4K chunks, or say a 15K KV cache + 15K chunks
standby=False, vllm_gpu_util=0.9, num_gen=2, grad_acc_steps=2 | Model loads but can't train: even batch size 1 doesn't fit | OOM                                 |
standby=False, vllm_gpu_util=0.8, num_gen=2, grad_acc_steps=2 | Model loads but can't train: even batch size 1 doesn't fit | OOM                                 |
standby=False, vllm_gpu_util=0.7, num_gen=2, grad_acc_steps=2 | Trains fine; 28 steps take 39 minutes                    | ~15.1 GiB                            | Any slightly longer input results in an OOM on Colab
standby=True, vllm_gpu_util=0.7, num_gen=2, grad_acc_steps=2  | Trains fine; 29 steps take 40 minutes                    | 13 GiB, mostly around 10-11 GiB      | At the same config we save 2 GiB (about 15%) of memory; savings can be higher for longer sequences

H100 Experiments

Model                | GPU                   | Seq Len | Num Generations | Grad Acc Steps
Qwen2.5-14B-Instruct | NVIDIA H100 80GB PCIe | 32,768  | 8               | 4

In our collapsible results below, you can see a 9 GiB difference in peak memory usage (note that 90% of the time, GPU memory usage is equal to the peak memory in our case). To put things into perspective, using TRL and LoRA we were only able to fine-tune an 8B parameter model with a context length of at most 1,024; anything with a longer sequence length (and a similar configuration) fails with an OOM.
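
The summaries in the collapsible section below come from PyTorch's built-in allocator statistics; a minimal sketch of how to capture the same numbers around a training run:

import torch

torch.cuda.reset_peak_memory_stats()

# ... run the GRPO training job here ...

peak_gib = torch.cuda.max_memory_allocated() / 1024**3
print(f"peak allocated: {peak_gib:.1f} GiB")
print(torch.cuda.memory_summary())   # prints a table like the ones below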

Click for Unsloth Standby Mode vs. no Standby Benchmarks
Standby mode enabled:

|===========================================================================|
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|===========================================================================|
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |  32249 MiB |  43042 MiB | 128336 GiB | 128305 GiB |
|       from large pool |  31415 MiB |  42165 MiB | 127204 GiB | 127173 GiB |
|       from small pool |    834 MiB |   1184 MiB |   1132 GiB |   1131 GiB |
|---------------------------------------------------------------------------|
| Active memory         |  32249 MiB |  43042 MiB | 128336 GiB | 128305 GiB |
|       from large pool |  31415 MiB |  42165 MiB | 127204 GiB | 127173 GiB |
|       from small pool |    834 MiB |   1184 MiB |   1132 GiB |   1131 GiB |
|---------------------------------------------------------------------------|
| Requested memory      |  32199 MiB |  42987 MiB | 128176 GiB | 128145 GiB |
|       from large pool |  31364 MiB |  42110 MiB | 127047 GiB | 127016 GiB |
|       from small pool |    834 MiB |   1184 MiB |   1129 GiB |   1128 GiB |
|---------------------------------------------------------------------------|
| GPU reserved memory   |  37644 MiB |  47504 MiB | 705806 MiB | 668162 MiB |
|       from large pool |  36376 MiB |  46588 MiB | 682818 MiB | 646442 MiB |
|       from small pool |   1268 MiB |   1284 MiB |  22988 MiB |  21720 MiB |
|---------------------------------------------------------------------------|
| Non-releasable memory | 713142 KiB |   4633 MiB | 103206 GiB | 103205 GiB |
|       from large pool | 525312 KiB |   4594 MiB | 101923 GiB | 101922 GiB |
|       from small pool | 187830 KiB |    250 MiB |   1283 GiB |   1283 GiB |
|---------------------------------------------------------------------------|
| Allocations           |    3460    |    4809    |   15606 K  |   15603 K  |
|       from large pool |     395    |     563    |    2812 K  |    2811 K  |
|       from small pool |    3065    |    4270    |   12794 K  |   12791 K  |
|---------------------------------------------------------------------------|
| Active allocs         |    3460    |    4809    |   15606 K  |   15603 K  |
|       from large pool |     395    |     563    |    2812 K  |    2811 K  |
|       from small pool |    3065    |    4270    |   12794 K  |   12791 K  |
|---------------------------------------------------------------------------|
| GPU reserved segments |     913    |     920    |   13260    |   12347    |
|       from large pool |     279    |     305    |    1766    |    1487    |
|       from small pool |     634    |     642    |   11494    |   10860    |
|---------------------------------------------------------------------------|
| Non-releasable allocs |     422    |     628    |    4766 K  |    4765 K  |
|       from large pool |      66    |      92    |    1290 K  |    1289 K  |
|       from small pool |     356    |     555    |    3476 K  |    3475 K  |
|---------------------------------------------------------------------------|
| Oversize allocations  |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| Oversize GPU segments |       0    |       0    |       0    |       0    |
|===========================================================================|


Without Standby:

|===========================================================================|
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|===========================================================================|
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |  32711 MiB |  52084 MiB | 142756 GiB | 142724 GiB |
|       from large pool |  31877 MiB |  51207 MiB | 141499 GiB | 141467 GiB |
|       from small pool |    834 MiB |   1184 MiB |   1257 GiB |   1256 GiB |
|---------------------------------------------------------------------------|
| Active memory         |  32711 MiB |  52084 MiB | 142756 GiB | 142724 GiB |
|       from large pool |  31877 MiB |  51207 MiB | 141499 GiB | 141467 GiB |
|       from small pool |    834 MiB |   1184 MiB |   1257 GiB |   1256 GiB |
|---------------------------------------------------------------------------|
| Requested memory      |  32572 MiB |  51658 MiB | 141898 GiB | 141866 GiB |
|       from large pool |  31738 MiB |  50780 MiB | 140644 GiB | 140613 GiB |
|       from small pool |    833 MiB |   1184 MiB |   1253 GiB |   1252 GiB |
|---------------------------------------------------------------------------|
| GPU reserved memory   |  49552 MiB |  52188 MiB |  86354 MiB |  36802 MiB |
|       from large pool |  48320 MiB |  51300 MiB |  84740 MiB |  36420 MiB |
|       from small pool |   1232 MiB |   1232 MiB |   1614 MiB |    382 MiB |
|---------------------------------------------------------------------------|
| Non-releasable memory |      0 B   |      0 B   |      0 B   |      0 B   |
|       from large pool |      0 B   |      0 B   |      0 B   |      0 B   |
|       from small pool |      0 B   |      0 B   |      0 B   |      0 B   |
|---------------------------------------------------------------------------|
| Allocations           |    3460    |    4809    |   17440 K  |   17437 K  |
|       from large pool |     395    |     564    |    2742 K  |    2741 K  |
|       from small pool |    3065    |    4270    |   14698 K  |   14695 K  |
|---------------------------------------------------------------------------|
| Active allocs         |    3460    |    4809    |   17440 K  |   17437 K  |
|       from large pool |     395    |     564    |    2742 K  |    2741 K  |
|       from small pool |    3065    |    4270    |   14698 K  |   14695 K  |
|---------------------------------------------------------------------------|
| GPU reserved segments |       0    |       0    |       0    |       0    |
|       from large pool |       0    |       0    |       0    |       0    |
|       from small pool |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| Non-releasable allocs |       0    |       0    |       0    |       0    |
|       from large pool |       0    |       0    |       0    |       0    |
|       from small pool |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| Oversize allocations  |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| Oversize GPU segments |       0    |       0    |       0    |       0    |
|===========================================================================|

The image below shows how Standby compares against non-Standby training in Unsloth, averaged over 3 runs to make sure the metrics aren't noisy. In fact, if you zoom in close enough, you'll see that enabling Standby is slightly faster as well, probably due to less memory pressure.

Previous A100 40GB experiments

In our previous experiments on an A100 40GB GPU with Qwen2.5-3B-Instruct and 8 generations per sample, we observed that without Standby, GRPO training (model loaded in 16-bit with LoRA, only the adapter weights trainable) could only fit 6K sequence lengths. With our Standby feature, we were able to fit 10K and beyond! For comparison, TRL can only give you context lengths of up to 1K with the same batch size.

🎉Other optimizations

We now select better compilation flags and reduce compile times by 50% or more. We also dynamically patch any vLLM version to handle gc.collect better for backwards compatibility, as inspired by this vLLM pull request.

We also optimized torch.compile flags and experimented with turning on extra ones. Unfortunately combo_kernels and multi_kernel could not function correctly, and coordinate_descent_tuning made autotuning all kernels dramatically slower, going from under a minute to over 13 minutes.
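
If you want to experiment with such flags yourself, they live under torch._inductor.config; this is a hedged sketch only, since flag names, availability, and defaults vary across PyTorch versions:

import torch
import torch._inductor.config as inductor_config

# Availability and defaults of these options differ across PyTorch releases;
# treat this as an experiment, not a recommendation.
inductor_config.coordinate_descent_tuning = False   # we found this made autotuning far slower
inductor_config.combo_kernels = False               # did not function correctly in our tests

compiled_model = torch.compile(model)               # `model` assumed to be defined earlier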

📚GRPO Notebooks

All our GRPO notebooks have Unsloth Standby and all optimizations enabled by default! See https://docs.unsloth.ai/get-started/unsloth-notebooks for all our GRPO notebooks, or try the below:
