⁉️FP16 vs BF16 for RL

The paper "Defeating the Training-Inference Mismatch via FP16" (https://arxiv.org/pdf/2510.26788) shows that float16 works better than bfloat16 for reinforcement learning.

Float16 vs Bfloat16

The paper "Defeating the Training-Inference Mismatch via FP16" (https://arxiv.org/pdf/2510.26788) shows that float16 precision can be dramatically better than bfloat16 for reinforcement learning.

In fact, the longer the generation, the worse bfloat16 fares.
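The difference comes down to how the two formats split their 16 bits: float16 has 5 exponent and 10 mantissa bits, while bfloat16 has 8 exponent and only 7 mantissa bits. The sketch below uses only Python's standard library to emulate both formats. float16 uses struct's "e" code; bfloat16 is emulated by rounding the float32 bit pattern to its top 16 bits, assuming round-to-nearest-even. It shows the trade-off: float16 rounds about 8x more finely, while bfloat16 keeps float32's exponent range. The helper names `round_fp16` and `round_bf16` are illustrative, not part of any library.

```python
import struct

def round_fp16(x: float) -> float:
    """Round x to IEEE 754 float16 (round to nearest even) and back to a Python float."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def round_bf16(x: float) -> float:
    """Round x to bfloat16: the top 16 bits of its float32 encoding,
    rounded to nearest with ties to even. NaN/inf handling is omitted."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)   # round-to-nearest-even bias
    bits &= 0xFFFF0000                    # drop the low 16 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

# float16 keeps 10 explicit mantissa bits, bfloat16 only 7.
x = 1 + 2**-10
print(round_fp16(x), round_bf16(x))  # fp16 keeps the increment, bf16 rounds it away

third = 1 / 3
for name, fn in [("fp16", round_fp16), ("bf16", round_bf16)]:
    rel_err = abs(fn(third) - third) / third
    print(f"{name}: 1/3 -> {fn(third)!r}, relative error {rel_err:.2e}")

# The trade-off: bfloat16 keeps float32's exponent range, float16 tops out at 65504.
print("bf16(70000) =", round_bf16(70000.0))
try:
    round_fp16(70000.0)
except (OverflowError, struct.error) as exc:
    print("fp16 overflow:", exc)
```

float16's narrower range is why it usually needs loss scaling. The finer mantissa means that two engines computing the same value round it less differently.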

We ran our own investigation and DO find float16 to be more stable than bfloat16, with much smaller gradient norms; see https://x.com/danielhanchen/status/1985557028295827482 and https://x.com/danielhanchen/status/1985562902531850472
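The following toy simulation illustrates why finer rounding could shrink the training-inference mismatch. It is our own illustration, not the paper's experiment or the investigation linked above. It draws a random vector of logits and rounds each logit to float16 or bfloat16. It then recomputes the softmax in float64 and measures the KL divergence from the unrounded distribution. Both formats are emulated with the standard library, and `round_bf16` assumes round-to-nearest-even on the float32 bits. The seed, vocabulary size and logit scale are arbitrary choices.

```python
import math
import random
import struct

def round_fp16(x: float) -> float:
    """Round x to IEEE 754 float16 and back (struct's 'e' format)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def round_bf16(x: float) -> float:
    """Round x to bfloat16 (top 16 bits of float32, round-to-nearest-even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]

def log_softmax(logits):
    """Numerically stable log-softmax computed in float64."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_z for z in logits]

def kl_divergence(p_log, q_log):
    """KL(p || q) for two distributions given as lists of log-probabilities."""
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(p_log, q_log))

rng = random.Random(0)                                # fixed, arbitrary seed
logits = [rng.gauss(0.0, 1.0) for _ in range(1000)]  # toy 1000-token vocabulary
reference = log_softmax(logits)                       # unrounded float64 distribution

results = {}
for name, rounder in [("fp16", round_fp16), ("bf16", round_bf16)]:
    perturbed = log_softmax([rounder(z) for z in logits])  # only the logits are rounded
    results[name] = kl_divergence(reference, perturbed)
    print(f"{name}: KL(float64 || {name}) = {results[name]:.3e}")
```

In this setup the bfloat16 KL comes out substantially larger, consistent with its coarser spacing. Real inference and training engines also differ in kernels and accumulation order, which this sketch does not model.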

🤯A100 Cascade Attention Bug

As per https://x.com/RichardYRLi/status/1984858850143715759 and https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda, older vLLM versions (before 0.11.0) had a broken cascade attention implementation on A100 and similar GPUs. Please update vLLM! If we detect an older vLLM version, Unsloth also disables cascade attention in vLLM by default during reinforcement learning.
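The version gate described above can be sketched as a small helper. `parse_version` and `should_disable_cascade_attention` are hypothetical names, not Unsloth's actual code. The 0.11.0 threshold comes from the text, and the sketch does not show how the resulting decision is passed on to vLLM.

```python
import re

# Threshold taken from the text: vLLM releases before 0.11.0 are affected.
FIXED_VLLM_VERSION = (0, 11, 0)

def parse_version(version: str) -> tuple:
    """Parse the leading 'major.minor[.patch]' of a version string.

    Anything after that (e.g. 'rc1', '.post1', '+cu128') is ignored, so a
    pre-release such as '0.11.0rc1' is treated as 0.11.0 in this sketch."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    major, minor, patch = match.groups()
    return (int(major), int(minor), int(patch or 0))

def should_disable_cascade_attention(vllm_version: str) -> bool:
    """Hypothetical helper: True when vLLM is older than the fixed release."""
    return parse_version(vllm_version) < FIXED_VLLM_VERSION

for v in ["0.10.1.1", "0.10.2", "0.11.0", "0.11.1+cu128"]:
    print(v, "-> disable cascade attention:", should_disable_cascade_attention(v))
```

In a real environment, the installed version string could be read with `importlib.metadata.version("vllm")` from the standard library.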

Different hardware also changes results: newer and more expensive GPUs show a smaller KL divergence between the inference and training sides.
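One common way to quantify the gap between the inference and training sides is to score the same sampled tokens with both engines. KL(pi_infer || pi_train) can then be estimated from the per-token log-probabilities. The sketch below is an assumption about how such a metric can be computed, not necessarily the exact metric behind the hardware comparison, and `mismatch_kl` is a hypothetical name. Because the tokens are sampled from the inference engine, the mean of log pi_infer - log pi_train is an unbiased estimate (k1). The mean of (r - 1) - log r, with r = pi_train / pi_infer, is an unbiased alternative that is never negative (k3).

```python
import math

def mismatch_kl(infer_logprobs, train_logprobs):
    """Estimate KL(pi_infer || pi_train) from the log-probabilities that the
    inference engine and the trainer assign to the same *sampled* tokens.

    Returns (k1, k3):
      k1 = mean(log pi_infer - log pi_train), unbiased but can be negative;
      k3 = mean((r - 1) - log r) with r = pi_train / pi_infer, unbiased and
           never negative.
    Both are unbiased only because the tokens were sampled from pi_infer."""
    if not infer_logprobs or len(infer_logprobs) != len(train_logprobs):
        raise ValueError("expected two non-empty lists of equal length")
    k1 = k3 = 0.0
    for lp_infer, lp_train in zip(infer_logprobs, train_logprobs):
        log_r = lp_train - lp_infer
        k1 += -log_r
        k3 += math.expm1(log_r) - log_r  # expm1 keeps precision when r is close to 1
    n = len(infer_logprobs)
    return k1 / n, k3 / n

# Made-up per-token log-probs for one short completion, scored by each side.
infer = [-0.10, -2.30, -0.70, -1.20]
train = [-0.11, -2.25, -0.72, -1.19]
k1, k3 = mismatch_kl(infer, train)
print(f"k1 = {k1:.2e}, k3 = {k3:.2e}")
```

In practice these estimates would be averaged over many completions, since a single short completion gives a noisy value.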

🔥Using float16 in Unsloth RL

To use float16 precision in Unsloth GRPO and RL, you just need to set dtype = torch.float16 and we'll take care of the rest!

```python
from unsloth import FastLanguageModel
import torch
max_seq_length = 2048 # Can increase for longer reasoning traces
lora_rank = 32 # Larger rank = smarter, but slower

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Qwen3-4B-Base",
    max_seq_length = max_seq_length,
    load_in_4bit = False, # False for LoRA 16bit
    fast_inference = True, # Enable vLLM fast inference
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.9, # Reduce if out of memory
    dtype = torch.float16, # torch.float16 or torch.bfloat16
)
```
