🎱FP8 Reinforcement Learning

Train models with reinforcement learning (RL) and GRPO in FP8 precision with Unsloth.

We're introducing FP8-precision training for RL, making FP8 GRPO possible on consumer GPUs (RTX 40 and 50 series, etc.). DeepSeek-R1 demonstrated how powerful FP8 can be, and with Unsloth, Qwen3-1.7B FP8 GRPO now works on just 5GB of VRAM.

Faster RL inference is critical as it's the most compute-intensive workload in RL. We collaborated with the TorchAO team from PyTorch to enable these performance gains with no loss in accuracy.

  • ~1.4× faster RL inference via vLLM

  • 2× longer context vs. BF16 and FP16

  • 60% less VRAM and 10× longer context than other FP8 RL implementations

  • Unsloth is the only framework to make FP8 RL LoRA work on consumer GPUs (e.g. NVIDIA GeForce RTX 40 and 50 Series). Also works on H100, H200, B200 etc.

  • Use load_in_fp8 = True within FastLanguageModel to enable FP8 RL.

  • Though Qwen3-8B fits in 16GB VRAM, free Colab NVIDIA Tesla T4 GPUs don’t support FP8. So our notebooks use 24GB L4 GPUs, which can also fit Qwen3-14B.

Notebooks: Qwen3-8B FP8 GRPO and Llama-3.2-1B FP8 GRPO

Our FP8 support uses Unsloth’s weight-sharing feature, reducing VRAM use by another 50% and enabling 10× more context with no accuracy loss. We use vLLM for fast inference, and our techniques like Unsloth Standby and Flex Attention further reduce VRAM use. TorchAO enables universal on-the-fly FP8, so Llama, Gemma, Mistral, and more all work. We’ve also uploaded most FP8 models (including Qwen3).

Reward plot shows FP8 following the same trend as BF16

🌻FP8 vs BF16 Training

Research shows that FP8 training can largely match BF16 accuracy, and if you serve models in FP8, training and serving in the same precision helps preserve accuracy. FP8 also yields 1.6x higher throughput than BF16 on H100s and uses 2x less memory.

Weight scales & FP8 types

Quantized training stores a low-precision weight (e.g., FP8) plus a higher-precision scale (FP16/BF16/FP32). You approximately recover the original weight via: original_weight ≈ quantized_weight * weight_scale

The scale maps the weight’s range into FP8’s representable range. More scales usually improve accuracy, but scales cost extra high-precision memory, so it’s a tradeoff. DeepSeek R1, for instance, mostly favors block quantization.
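As a concrete illustration, here is a minimal per-tensor sketch in PyTorch (assuming torch.float8_e4m3fn is available; block-wise and per-channel schemes simply compute more scales over smaller slices of the tensor):

import torch

FP8 = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8).max  # ~448 for e4m3

weight = torch.randn(4096, 4096)

# One high-precision scale maps the tensor's range into FP8's representable range
weight_scale = weight.abs().amax() / FP8_MAX
quantized_weight = (weight / weight_scale).clamp(-FP8_MAX, FP8_MAX).to(FP8)

# original_weight ≈ quantized_weight * weight_scale
recovered_weight = quantized_weight.to(torch.float32) * weight_scale
print((weight - recovered_weight).abs().max())  # small quantization error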

There are 3 common FP8 types as defined by vLLM's llm-compressor. We benchmarked Qwen3-8B on all 3 types, checking throughput, MMLU Pro, and GPQA Diamond. We find FP8 Block-wise and Per-Channel ("-FP8-Dynamic") to be the best in terms of accuracy and throughput.

| Type | Scale granularity | Throughput | MMLU Pro | GPQA Diamond |
| --- | --- | --- | --- | --- |
| Bfloat16 baseline | — | 11,367 | 62.04% | 28.79% |
| Block-wise | Scales per 128×128 block | 12,041 | 62.37% | 29.29% |
| Per-Channel | 1 scale per row or column | 12,963 | 61.89% | 31.82% |
| Per-Tensor | 1 scale for the whole tensor | 13,681 | 61.83% | 27.78% |
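To make the three schemes concrete, here is a rough sketch of how the scale shapes differ for a 4096×4096 weight (illustrative only; real kernels store these scales and apply them inside the FP8 matmul):

import torch

w = torch.randn(4096, 4096)  # [out_features, in_features]
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

per_tensor  = w.abs().amax() / FP8_MAX                                   # 1 scale total
per_channel = w.abs().amax(dim = 1, keepdim = True) / FP8_MAX            # 1 scale per row
blocks      = w.reshape(32, 128, 32, 128)                                # 128x128 blocks
block_wise  = blocks.abs().amax(dim = (1, 3), keepdim = True) / FP8_MAX  # 1 scale per block

print(per_tensor.shape, per_channel.shape, block_wise.shape)
# torch.Size([]) torch.Size([4096, 1]) torch.Size([32, 1, 32, 1])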

FP8 Performance Benchmarks

Unsloth FP8 RL inference via vLLM is generally 1.4x faster than BF16. You may see even larger speedups on bigger models!

Accuracy & Training Loss Benchmarks

We tested multiple models including Qwen3-4B, 8B, 14B, Llama 3.2 1B, 3B, Qwen3-VL-2B, Qwen3-VL 4B and many more. All were trained in both BF16 and FP8. As seen in the plots, the loss curves during SFT for BF16 and FP8 closely track each other; there isn’t much to choose between the two data types in terms of training loss.

For GRPO specifically, generations differ between runs, so the goal is for the reward plots to at least track each other rather than diverge (e.g., Qwen3-14B runs might not be exactly identical).

⛩️Inference = 96% of RL training

In RL, we have to call the LLM / VLM to generate candidate solutions for each prompt, then score every candidate, rewarding good solutions and penalizing bad answers. To achieve maximum efficiency, we must make inference nearly 100% of the training run. In Unsloth, we managed to make training take only <4% of the entire RL run, with 96% being purely vLLM inference.

For example, for Qwen3-8B at shorter sequence lengths, vLLM FP8 inference alone (without training) has 1.15x higher throughput than BF16, and our full RL run in Unsloth also processes tokens 1.15x faster, showing that training overhead in Unsloth is negligible.
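A quick back-of-the-envelope check of why the end-to-end speedup essentially equals the inference speedup (assuming the 96% / 4% split above):

# If 96% of wall-clock time is vLLM inference and 4% is training,
# speeding up only inference by 1.15x gives roughly:
inference_frac, train_frac = 0.96, 0.04
fp8_inference_speedup = 1.15
overall = 1 / (train_frac + inference_frac / fp8_inference_speedup)
print(f"{overall:.2f}x")  # ~1.14x, i.e. nearly the full inference speedup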

🔢60% less memory usage

In theory, you’d expect the memory savings to come almost entirely from the model’s weights (1 byte saved per parameter going from BF16 to FP8), because optimizer states are still stored in high precision and activations are also stored in high precision (for now). Our findings match the theory. For LoRA fine-tuning, we observed ~30 GB saved for Qwen3-32B, ~14 GB saved for Qwen2.5-14B, and ~8 GB saved for Qwen3-8B.
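A rough sanity check of those numbers (assumption: the savings come only from the quantized base weights, 2 bytes/parameter in BF16 vs 1 byte/parameter in FP8, ignoring scales and any layers kept in higher precision):

# Expected weight-memory saving ≈ 1 byte per parameter
for name, params_in_billions in [("Qwen3-32B", 32), ("Qwen2.5-14B", 14), ("Qwen3-8B", 8)]:
    bf16_gb = params_in_billions * 2  # GB of weights in BF16
    fp8_gb  = params_in_billions * 1  # GB of weights in FP8
    print(f"{name}: expected saving ≈ {bf16_gb - fp8_gb} GB")
# Qwen3-32B ≈ 32 GB, Qwen2.5-14B ≈ 14 GB, Qwen3-8B ≈ 8 GB,
# close to the ~30 GB, ~14 GB and ~8 GB we measured.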

For BF16 LoRA fine-tuning on Qwen3-32B, we were hitting out-of-memory (OOM) errors at higher batch sizes and had to shrink the batch. The FP8 variant had no such issues, and we could use larger batch sizes without OOMing.

Also, as a reminder, Unsloth shares vLLM's memory space for the weights as introduced in Memory Efficient RL - we have brought this trick over to the FP8 domain!

Diagram: on an 80GB GPU, the ~8GB of FP8 model weights are shared between the inference engine and the training engine, leaving ~72GB of multi-purpose space for the KV cache, activations, gradients, and optimizer states.

To enable Unsloth Standby for FP8 (or BF16) RL, simply add the below to all RL / GRPO training runs before any Unsloth import:

import os
os.environ["UNSLOTH_VLLM_STANDBY"] = "1"

How to use FP8 RL / installation

Simply update Unsloth, or install Unsloth in a new virtual environment, on an H100, L4, RTX 50 series, RTX 40 series, H200, B200, or any NVIDIA GPU (consumer or data center grade) from the RTX 4090 onwards.

To update Unsloth:

pip install --upgrade --force-reinstall --no-cache-dir --no-deps unsloth unsloth_zoo

Or make a new environment:

python -m venv unsloth_env
source unsloth_env/bin/activate

pip install unsloth vllm
pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu128 --force-reinstall
pip install --pre fbgemm-gpu fbgemm-gpu-genai --index-url https://download.pytorch.org/whl/cu128 --force-reinstall
pip install --upgrade numba numpy

Then use load_in_fp8 = True and you're good to go! We'll auto-map the model name to its Float8 variant, or convert the model to Float8 on the fly!

import os
os.environ['UNSLOTH_VLLM_STANDBY'] = "1" # Unsloth standby saves 30%+ memory for RL
from unsloth import FastLanguageModel
import torch
max_seq_length = 2048 # Can increase for longer reasoning traces
lora_rank = 32 # Larger rank = smarter, but slower
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Qwen3-8B",
    max_seq_length = max_seq_length,
    load_in_4bit = False, # False for LoRA 16bit
    fast_inference = True, # Enable vLLM fast inference
    max_lora_rank = lora_rank,
    load_in_fp8 = True, # Float8 RL / GRPO!
)

For example on an RTX 5090 (reminder to set os.environ["UNSLOTH_VLLM_STANDBY"] = "1"):

Then use our two FP8 notebooks for RL: Qwen3-8B FP8 GRPO and Llama-3.2-1B FP8 GRPO.
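The notebooks follow roughly this shape (a hedged sketch continuing the FastLanguageModel setup above via TRL's GRPOTrainer; the dataset and reward function here are toy placeholders, not the ones used in the notebooks):

from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer

# Attach trainable BF16 LoRA adapters on top of the FP8 base weights
model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank,
    lora_alpha = lora_rank * 2,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj"],
)

# Toy dataset and reward: prefer shorter completions (replace with your own)
train_dataset = Dataset.from_dict({"prompt": ["What is 2 + 2?"] * 64})

def reward_shorter(completions, **kwargs):
    return [-float(len(c)) for c in completions]

trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [reward_shorter],
    args = GRPOConfig(
        num_generations = 4,
        max_completion_length = 256,
        max_steps = 50,
        output_dir = "outputs",
    ),
    train_dataset = train_dataset,
)
trainer.train()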

💿Implementing FP8 Training

Our first reference point was transformers, which already supports FP8 in a couple of ways. One of them is a block-quantized matmul implementation: when a layer receives 16‑bit activations, it quantizes them and passes them to a custom FP8 matmul kernel. After wiring this up and benchmarking on an NVIDIA H100, we saw the opposite of what we wanted: fine-tuning became about 4× slower than standard BF16 fine-tuning.

🔥TorchAO Collab

So we worked with the TorchAO team (huge thanks to Andrew) to incorporate TorchAO’s FP8 support into our RL workloads and saw around 1.4× faster throughput and up to 60% less model memory usage. At a high level:

  • We store the frozen base model weights (the weights LoRA is applied on top of) in FP8.

  • During the forward pass, we apply dynamic FP8 quantization to the input activations, while keeping the trainable LoRA adapters in BF16.

  • These FP8 weights share the same buffers as the vLLM model weights, so there’s only a single FP8 copy of the model in memory at any time (no “double model” memory overhead).

  • In the backward pass, we dequantize the FP8 weights so all gradient computation is done in BF16 for better accuracy.

This general setup works across all supported RL algorithms, including GSPO, Dr. GRPO, PPO, and DPO.
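At a lower level, the forward pass looks roughly like the sketch below (illustrative only: Fp8LoraLinear is a hypothetical class, the scaling is simplified to per-tensor, and the matmul is done in BF16 for clarity where the real TorchAO / FBGEMM kernels operate on FP8 directly):

import torch
import torch.nn as nn

FP8 = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8).max

def quantize_per_tensor(x):
    # One scale for the whole tensor: map max |value| into FP8's range
    scale = x.abs().amax().clamp(min = 1e-12) / FP8_MAX
    return (x / scale).clamp(-FP8_MAX, FP8_MAX).to(FP8), scale

class Fp8LoraLinear(nn.Module):  # hypothetical, for illustration only
    def __init__(self, weight, rank = 32):
        super().__init__()
        out_features, in_features = weight.shape
        # Frozen base weight stored in FP8 plus a high-precision scale
        w_fp8, w_scale = quantize_per_tensor(weight)
        self.register_buffer("w_fp8", w_fp8)
        self.register_buffer("w_scale", w_scale)
        # Trainable LoRA adapters stay in BF16
        self.lora_A = nn.Parameter(torch.randn(rank, in_features, dtype = torch.bfloat16) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(out_features, rank, dtype = torch.bfloat16))

    def forward(self, x):
        # Dynamic FP8 quantization of the input activations
        x_fp8, x_scale = quantize_per_tensor(x)
        # Dequantize here for clarity; real kernels multiply in FP8 and apply scales afterwards
        w_bf16 = self.w_fp8.to(torch.bfloat16) * self.w_scale
        base = (x_fp8.to(torch.bfloat16) * x_scale) @ w_bf16.t()
        lora = (x.to(torch.bfloat16) @ self.lora_A.t()) @ self.lora_B.t()
        return base + lora

layer = Fp8LoraLinear(torch.randn(64, 128), rank = 8)
print(layer(torch.randn(2, 128)).shape)  # torch.Size([2, 64])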

TorchAO provides PyTorch-native FP8 support for both training and inference, offering a variety of scaling granularities including tensorwise, row-wise, and 128x128 blockwise (prototype). TorchAO’s FP8 support can improve inference throughput by up to 1.64x at 27B scale with row-wise scaling granularity. For more details, visit the TorchAO FP8 README.

TorchAO’s block-quantized FP8 matmul

We used TorchAO’s block‑quantized FP8 matmul implementation which provided:

  • 80% of BF16 throughput

  • No degradation in loss or training stability

So for a while this became our default FP8 matmul backend, until FBGEMM caught up - we now default to FBGEMM's implementation if your GPU supports it! The current version of Unsloth automatically chooses the best backend based on what’s installed, so if you have the right packages, you don’t have to leave performance on the table 🙂

PS: We also experimented with DeepSeek’s DeepGEMM, but couldn’t get it fully integrated end‑to‑end to run clean, apples‑to‑apples comparisons.

🐦On the fly TorchAO FP8 quantization

Massive thanks to Andrew from TorchAO: Unsloth FP8 RL also lets you quantize the model on the fly, performing the quantization at model load time and passing the result on to vLLM. This way, you need not explicitly quantize the model yourself (we handle it for you). Just set load_in_fp8 = True in the model load arguments, and we'll quantize to FP8 on the fly if we don't find a suitable pre-quantized checkpoint.

from unsloth import FastLanguageModel
fp8_model = FastLanguageModel.from_pretrained(
    "unsloth/Llama-3.3-70B-Instruct", # Can be any model name!
    load_in_fp8 = True, # Can be "block" for block FP8, True for row FP8, False
)

🎉Unsloth FP8 uploads

For convenience, we uploaded FP8 Dynamic and FP8 Block models to Hugging Face. You can use them for FP8 training, or for efficient & fast serving/deployment via vLLM, SGLang, etc. (see the serving sketch after the table below).

FP8 Dynamic offers slightly faster training and lower VRAM usage than FP8 Block, but with a small trade-off in accuracy. See here for our full list of FP8 quants; here are the most popular ones:

| Model | FP8 uploads |
| --- | --- |
| Qwen3 (2507) | 4B Instruct — FP8, 4B Thinking — FP8, 30B-A3B Instruct — FP8, 30B-A3B Thinking — FP8 |
| Qwen3-VL | 4B Instruct — FP8, 4B Thinking — FP8, 8B Instruct — FP8, 8B Thinking — FP8 |
| Llama 3.1 | 8B Instruct — Dynamic · Block, 8B Base — Dynamic · Block, 70B — Dynamic · Block |
| Qwen3 | 0.6B — FP8, 1.7B — FP8, 4B — FP8, 8B — FP8, 14B — FP8, 32B — FP8 |
| Llama 3.3 | 70B — Dynamic · Block |
| Llama 3.2 | 1B Base — Dynamic · Block, 1B Instruct — Dynamic · Block, 3B Base — Dynamic · Block, 3B Instruct — Dynamic · Block |
| Granite 4.0 | h-tiny — FP8 Dynamic, h-small — FP8 Dynamic |
| Magistral Small | |
| Mistral Small 3.2 | |
| Gemma 3 | 270m — FP8, 1B — FP8, 4B — FP8, 12B — FP8, 27B — FP8 |
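For serving, a minimal vLLM sketch looks like this (the exact repo name below, unsloth/Qwen3-8B-FP8, is assumed; check the collection for the precise upload names):

from vllm import LLM, SamplingParams

llm = LLM(model = "unsloth/Qwen3-8B-FP8", max_model_len = 4096)
outputs = llm.generate(
    ["Explain FP8 training in one sentence."],
    SamplingParams(temperature = 0.7, max_tokens = 128),
)
print(outputs[0].outputs[0].text)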

💁Acknowledgements

Huge thanks to the entire PyTorch and TorchAO team for their help and collaboration! A huge thank you especially to Andrew Or, Jerry Zhang, Supriya Rao, Scott Roy, and Mergen Nachin for the many discussions on FP8 RL and for helping to integrate it into Unsloth! Thanks to the ExecuTorch team as well!
